Commit 0282354

[TRTLLM-6633][feat] Padding for piecewise cudagraph (#6750)
Signed-off-by: Jin Li <[email protected]>
1 parent 87d1d3a commit 0282354

File tree: 10 files changed, 324 additions & 218 deletions

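The commit pads the token dimension of a batch up to one of the pre-captured piecewise CUDA-graph sizes, runs the captured region on the padded buffers, and slices the valid tokens back out around the attention kernels. A minimal end-to-end sketch of that flow in plain PyTorch (hypothetical names, not the TRT-LLM implementation; the padded size is assumed to be chosen elsewhere):

import torch

def run_padded(hidden_states: torch.Tensor, num_tokens: int,
               padded_num_tokens: int) -> torch.Tensor:
    # Pad the token dimension up to the captured graph size...
    pad_rows = padded_num_tokens - num_tokens
    hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, pad_rows))
    out = hidden_states * 2.0            # stand-in for the piecewise-captured region
    # ...and slice the valid tokens back out afterwards.
    return out[:num_tokens]

x = torch.randn(5, 16)
assert run_padded(x, 5, 8).shape == (5, 16)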

cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp

Lines changed: 1 addition & 1 deletion
@@ -565,7 +565,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
         "merge_chunked_attention_for_mla("
-        "Tensor merged_attn"
+        "Tensor(a!) merged_attn"
         ", Tensor temp_attn"
         ", Tensor merged_softmax_stats"
         ", Tensor temp_softmax_stats"
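
The schema change only adds the (a!) alias annotation, which tells PyTorch that merged_attn is written in place by the op rather than being a pure input. A minimal sketch of that annotation on a custom op, assuming PyTorch 2.x torch.library (demo::scale_inplace is a hypothetical op, not the TRT-LLM one):

import torch

# "(a!)" marks the first argument as mutated in place, so torch.compile /
# functionalization treat the op as mutating rather than pure.
torch.library.define("demo::scale_inplace",
                     "(Tensor(a!) out, Tensor src, float factor) -> ()")

@torch.library.impl("demo::scale_inplace", "CompositeExplicitAutograd")
def scale_inplace(out: torch.Tensor, src: torch.Tensor, factor: float) -> None:
    out.copy_(src * factor)

out = torch.empty(4)
torch.ops.demo.scale_inplace(out, torch.ones(4), 2.0)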

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 3 additions & 0 deletions
@@ -135,6 +135,9 @@ class AttentionMetadata:
     _num_ctx_tokens: int = field(init=False, default=0, repr=False)
     _num_tokens: int = field(init=False, default=0, repr=False)
 
+    # The number of tokens in the padded sequence.
+    padded_num_tokens: Optional[int] = None
+
     # This buffer is currently only used for TrtllmAttentionMetadata.
     cache_indirection: Optional[torch.Tensor] = None
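
With both counts on the metadata, downstream modules can detect padding by checking padded_num_tokens and recover the valid rows via num_tokens. A minimal sketch of that contract (a simplified stand-in for AttentionMetadata, not the real class):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class _Meta:                         # simplified stand-in for AttentionMetadata
    num_tokens: int
    padded_num_tokens: Optional[int] = None

meta = _Meta(num_tokens=5, padded_num_tokens=8)
hidden_states = torch.zeros(meta.padded_num_tokens or meta.num_tokens, 16)
if meta.padded_num_tokens is not None:
    hidden_states = hidden_states[:meta.num_tokens]   # drop the padding rows
assert hidden_states.shape[0] == 5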

tensorrt_llm/_torch/compilation/backend.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def __init__(
         self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
         self.rank = tensorrt_llm.mpi_rank()
         self.enable_inductor = enable_inductor
-        self.capture_num_tokens = capture_num_tokens or []
+        self.capture_num_tokens = sorted(capture_num_tokens or [])
         self.piecewise_cuda_graph = enable_piecewise_cuda_graph
         self.no_optimization = False
         # We only need to create aux streams.
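
Keeping capture_num_tokens sorted makes it cheap to find the padding target for a batch: the smallest captured size that is at least the real token count. A sketch of that lookup (hypothetical helper, assuming the list is ascending as in the constructor above):

import bisect
from typing import List, Optional

def pick_padded_num_tokens(num_tokens: int,
                           capture_num_tokens: List[int]) -> Optional[int]:
    # capture_num_tokens is assumed sorted ascending.
    idx = bisect.bisect_left(capture_num_tokens, num_tokens)
    if idx == len(capture_num_tokens):
        return None              # larger than every captured size: no padding
    return capture_num_tokens[idx]

assert pick_padded_num_tokens(3, [1, 2, 4, 8]) == 4
assert pick_padded_num_tokens(8, [1, 2, 4, 8]) == 8
assert pick_padded_num_tokens(9, [1, 2, 4, 8]) is None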

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 7 additions & 4 deletions
@@ -11,8 +11,9 @@
 
 from tensorrt_llm.llmapi.utils import enable_llm_debug
 
-from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
-                     make_weak_ref)
+from ..utils import (get_model_extra_attrs,
+                     get_per_request_piecewise_cuda_graph_flag,
+                     get_piecewise_cuda_graph_flag, make_weak_ref)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
 
@@ -154,8 +155,10 @@ def __call__(self, *args):
         elif isinstance(self.compile_time_num_tokens, int):
             runtime_num_of_token = self.compile_time_num_tokens
 
-        if runtime_num_of_token is None or runtime_num_of_token not in self.entries or not get_piecewise_cuda_graph_flag(
-        ):
+        if (runtime_num_of_token is None
+                or runtime_num_of_token not in self.entries
+                or not get_piecewise_cuda_graph_flag()
+                or not get_per_request_piecewise_cuda_graph_flag()):
             return self.default_callable(*args)
 
         entry = self.entries[runtime_num_of_token]
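
The new get_per_request_piecewise_cuda_graph_flag() check lets a single forward pass fall back to the eager callable (for example when its token count was not padded to a captured size) while the global flag stays enabled. A minimal sketch of such a per-request toggle (hypothetical names, not the TRT-LLM utilities):

from contextlib import contextmanager

_per_request_flag = True             # default: piecewise CUDA graphs allowed

def get_per_request_flag() -> bool:
    return _per_request_flag

@contextmanager
def per_request_piecewise_cuda_graph(enabled: bool):
    global _per_request_flag
    prev, _per_request_flag = _per_request_flag, enabled
    try:
        yield
    finally:
        _per_request_flag = prev

# Usage: force the eager fallback for one request only.
with per_request_piecewise_cuda_graph(False):
    assert get_per_request_flag() is False
assert get_per_request_flag() is True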

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 16 additions & 2 deletions
@@ -425,6 +425,9 @@ def forward(
             **kwargs,
         )
 
+        if attn_metadata.padded_num_tokens is not None:
+            hidden_states = hidden_states[:attn_metadata.num_tokens]
+
         if self.draft_model is not None:
             # get logits
             logits = self.logits_processor.forward(
@@ -433,9 +436,20 @@ def forward(
                 attn_metadata,
                 True,
             )
+            mtp_input_ids = input_ids
+            mtp_position_ids = position_ids
+            if attn_metadata.padded_num_tokens is not None:
+                if input_ids is not None:
+                    # Slice along the first dimension
+                    mtp_input_ids = input_ids[:attn_metadata.num_tokens]
+                if position_ids is not None:
+                    # Slice along the last dimension
+                    mtp_position_ids = position_ids[:, :attn_metadata.
+                                                    num_tokens]
+
             # get accepted tokens and next draft tokens
-            return self.spec_worker(input_ids=input_ids,
-                                    position_ids=position_ids,
+            return self.spec_worker(input_ids=mtp_input_ids,
+                                    position_ids=mtp_position_ids,
                                     hidden_states=hidden_states,
                                     logits=logits,
                                     attn_metadata=attn_metadata,
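
The slicing differs per tensor because the shapes differ: input_ids and hidden_states carry tokens along dim 0, while position_ids carries them along its last dimension. A small shape check under those assumptions:

import torch

num_tokens, padded_num_tokens, hidden = 5, 8, 16
input_ids = torch.zeros(padded_num_tokens, dtype=torch.long)
position_ids = torch.zeros(1, padded_num_tokens, dtype=torch.long)
hidden_states = torch.zeros(padded_num_tokens, hidden)

mtp_input_ids = input_ids[:num_tokens]             # dim 0
mtp_position_ids = position_ids[:, :num_tokens]    # last dim
hidden_states = hidden_states[:num_tokens]         # dim 0

assert mtp_input_ids.shape == (num_tokens,)
assert mtp_position_ids.shape == (1, num_tokens)
assert hidden_states.shape == (num_tokens, hidden)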

tensorrt_llm/_torch/modules/attention.py

Lines changed: 58 additions & 73 deletions
@@ -338,6 +338,19 @@ def _attn_impl(
         attention_sinks: Optional[torch.Tensor] = None,
     ):
 
+        padded_num_tokens = attn_metadata.padded_num_tokens
+        num_tokens = attn_metadata.num_tokens
+
+        if padded_num_tokens is not None:
+            assert q.shape[0] == padded_num_tokens
+            q = q[:num_tokens, :]
+            if k is not None:
+                assert k.shape[0] == padded_num_tokens
+                k = k[:num_tokens, :]
+            if v is not None:
+                assert v.shape[0] == padded_num_tokens
+                v = v[:num_tokens, :]
+
         out_scale = None
         out_scale_sf = None
         has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
@@ -368,7 +381,7 @@ def _attn_impl(
             attention_window_size=attention_window_size,
             attention_mask_data=attention_mask_data,
             enable_attn_nvfp4_output=enable_attn_nvfp4_output,
-            output=output,
+            output=output[:num_tokens, :] if output is not None else None,
             output_sf=output_sf,
             attention_sinks=attention_sinks)
         if isinstance(attn_output, tuple):
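
Because the slices are views, the kernel reads only the valid rows of q/k/v and writes only the first num_tokens rows of the preallocated output, while the buffer itself keeps the padded shape the captured graph expects. A small demonstration of that view behaviour in plain PyTorch:

import torch

padded_num_tokens, num_tokens, dim = 8, 5, 4
q = torch.randn(padded_num_tokens, dim)
output = torch.zeros(padded_num_tokens, dim)

q_valid = q[:num_tokens, :]                  # view, no copy
output[:num_tokens, :] = q_valid * 2.0       # writes through to `output`

assert output[:num_tokens, :].data_ptr() == output.data_ptr()
assert torch.all(output[num_tokens:] == 0)   # padding rows stay untouched
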
@@ -936,11 +949,10 @@ def create_output(self, hidden_states: torch.Tensor):
         return hidden_states.new_empty([num_tokens, hidden_size],
                                        dtype=hidden_states.dtype)
 
-    def forward_impl(self,
-                     position_ids: Optional[torch.Tensor],
+    def forward_impl(self, position_ids: Optional[torch.Tensor],
                      hidden_states: torch.Tensor,
                      attn_metadata: AttentionMetadata,
-                     output: Optional[torch.Tensor] = None) -> torch.Tensor:
+                     output: torch.Tensor) -> None:
         """
         Forward pass for the MLA module.
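
After this change, forward_impl always writes into a caller-provided output buffer and returns nothing, pairing with create_output above. A sketch of that calling convention (simplified stand-in, not the real module):

import torch

class _Mla:                             # simplified stand-in for the MLA module
    def create_output(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Preallocate the result buffer once, shaped like the (possibly padded) input.
        return hidden_states.new_empty(hidden_states.shape)

    def forward_impl(self, hidden_states: torch.Tensor,
                     output: torch.Tensor) -> None:
        # The kernel writes its result into the caller's buffer and returns nothing.
        output.copy_(hidden_states * 2.0)

m = _Mla()
x = torch.randn(6, 16)
out = m.create_output(x)
m.forward_impl(x, out)                  # no return value; `out` holds the result
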
@@ -953,6 +965,18 @@ def forward_impl(self,
         Returns:
             torch.Tensor: The output tensor.
         """
+        # split q, k, v into context and gen batches
+        num_contexts = attn_metadata.num_contexts
+        num_generations = attn_metadata.num_generations
+        num_ctx_tokens = attn_metadata.num_ctx_tokens
+        num_tokens = attn_metadata.num_tokens
+        padded_num_tokens = attn_metadata.padded_num_tokens
+
+        if padded_num_tokens is not None:
+            hidden_states = hidden_states[:num_tokens, ...]
+            if position_ids is not None:
+                position_ids = position_ids[:num_tokens, ...]
+
         if self.is_lite:
             compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)
@@ -980,15 +1004,11 @@
             self.aux_stream,
         )
 
-        # split q, k, v into context and gen batches
-        num_contexts = attn_metadata.num_contexts
-        num_generations = attn_metadata.num_generations
-        num_ctx_tokens = attn_metadata.num_ctx_tokens
-        num_tokens = attn_metadata.num_tokens
-
         assert q.shape[
             0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"
 
+        assert output is not None, "output must be provided"
+
         if num_contexts > 0:
             q_ctx = q[:num_ctx_tokens, ...]
             compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
@@ -998,17 +1018,14 @@
             assert position_ids is not None
             k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
 
-            attn_output_context = self.forward_context(
+            self.forward_context(
                 q_ctx,
                 compressed_kv_ctx,
                 k_pe_ctx,
                 attn_metadata,
+                output[:num_ctx_tokens, :],
                 latent_cache_ctx,
-                output=output if num_generations == 0 else None)
-            if num_generations == 0:
-                return attn_output_context
-            else:
-                attn_output_context = None
+            )
 
         if num_generations > 0:
             q_gen = q[num_ctx_tokens:, ...]
@@ -1019,48 +1036,24 @@
             assert position_ids is not None
             k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
 
-            attn_output_gen = self.forward_generation(
+            self.forward_generation(
                 q_gen,
                 compressed_kv_gen,
                 k_pe_gen,
                 attn_metadata,
+                output[num_ctx_tokens:num_tokens, :],
                 latent_cache_gen,
-                output=output if num_contexts == 0 else None)
-            if num_contexts == 0:
-                return attn_output_gen
-            else:
-                attn_output_gen = None
-
-        # release pytorch activation memory
-        q = None
-        compressed_kv = None
-        k_pe = None
-
-        assert attn_output_context is not None and attn_output_gen is not None
-        assert (
-            len(attn_output_context.shape) == 2
-        ), f"attn_output_context must be rank 2, not {len(attn_output_context.shape)}"
-        assert (
-            len(attn_output_gen.shape) == 2
-        ), f"attn_output_gen must be rank 2, not {len(attn_output_gen.shape)}"
-        output = output if output is not None else torch.empty(
-            (num_tokens, attn_output_context.shape[1]),
-            dtype=attn_output_context.dtype,
-            device=attn_output_context.device)
-        output[:attn_output_context.shape[0], :] = attn_output_context
-        output[attn_output_context.shape[0]:, :] = attn_output_gen
-        attn_output_context = None
-        attn_output_gen = None
-        return output
+            )
 
     def forward_context_default(
-            self,
-            q: torch.Tensor,
-            compressed_kv: torch.Tensor,
-            k_pe: torch.Tensor,
-            attn_metadata: AttentionMetadata,
-            latent_cache: Optional[torch.Tensor] = None,
-            output: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self,
+        q: torch.Tensor,
+        compressed_kv: torch.Tensor,
+        k_pe: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
+        latent_cache: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         kv = self.kv_b_proj(compressed_kv)
         k_nope, v = kv.split(
             [
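
After the refactor, the context and generation phases write into disjoint row ranges of the same preallocated buffer instead of returning tensors that are concatenated afterwards, removing the extra allocations and copies the deleted code performed. A minimal sketch of that pattern with stand-in kernels:

import torch

def run_context(q_ctx: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(q_ctx * 2.0)       # stand-in for forward_context

def run_generation(q_gen: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(q_gen * 3.0)       # stand-in for forward_generation

num_ctx_tokens, num_tokens, dim = 3, 5, 4
q = torch.randn(num_tokens, dim)
output = q.new_empty(num_tokens, dim)

run_context(q[:num_ctx_tokens], output[:num_ctx_tokens, :])
run_generation(q[num_ctx_tokens:], output[num_ctx_tokens:num_tokens, :])
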
@@ -1099,7 +1092,7 @@ def forward_context_with_cached_kv(
         q: torch.Tensor,
         latent_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        output: Optional[torch.Tensor] = None,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         assert latent_cache is not None
         trtllm_attention = cast(TrtllmAttention, self.mha)
@@ -1168,7 +1161,7 @@ def forward_context_with_chunked_prefill(
         latent_cache: torch.
         Tensor,  # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
         attn_metadata: TrtllmAttentionMetadata,
-        output: Optional[torch.Tensor] = None,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         trtllm_attention = cast(TrtllmAttention, self.mha)
         # apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
@@ -1190,11 +1183,8 @@ def forward_context_with_chunked_prefill(
             dtype=torch.float,
             device='cuda',
         )
-        if output is None:
-            attn_output = q.new_empty(
-                (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
-        else:
-            attn_output = output
+
+        attn_output = output
         temp_attn_output = q.new_empty(
             (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
@@ -1332,8 +1322,8 @@ def forward_context(
         compressed_kv: torch.Tensor,
         k_pe: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
         latent_cache: Optional[torch.Tensor] = None,
-        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if isinstance(self.mha, TrtllmAttention):
             assert isinstance(attn_metadata, TrtllmAttentionMetadata)
@@ -1346,16 +1336,17 @@ def forward_context(
             return self.forward_context_with_cached_kv(
                 q, latent_cache, attn_metadata, output)
         return self.forward_context_default(q, compressed_kv, k_pe,
-                                            attn_metadata, latent_cache, output)
+                                            attn_metadata, output, latent_cache)
 
     def forward_generation(
-            self,
-            q: torch.Tensor,
-            compressed_kv: torch.Tensor,
-            k_pe: torch.Tensor,
-            attn_metadata: AttentionMetadata,
-            latent_cache: Optional[torch.Tensor] = None,
-            output: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self,
+        q: torch.Tensor,
+        compressed_kv: torch.Tensor,
+        k_pe: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
+        latent_cache: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         num_tokens = q.shape[0]
         q_nope, q_pe = q.view([-1, self.num_heads, self.qk_head_dim]).split(
             [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
@@ -1427,12 +1418,6 @@ def forward_generation(
         attn_out_latent = attn_out_latent.view(
             [-1, self.num_heads, self.kv_lora_rank])
 
-        # [seq, num_heads * v_head_dim]
-        output = output if output is not None else torch.empty(
-            [num_tokens, self.num_heads * self.v_head_dim],
-            dtype=attn_out_latent.dtype,
-            device=attn_out_latent.device)
-
         attn_output = output.view([num_tokens, self.num_heads, self.v_head_dim])
 
         if self.v_b_proj.dtype == torch.bfloat16:
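
With the lazy allocation removed, forward_generation reshapes the caller's buffer with view and lets the projection write per-head results straight into it. A small illustration of that reshape-in-place pattern:

import torch

num_tokens, num_heads, v_head_dim = 4, 2, 3
output = torch.zeros(num_tokens, num_heads * v_head_dim)

# `view` shares storage with `output`, so writes land in the caller's buffer.
attn_output = output.view(num_tokens, num_heads, v_head_dim)
attn_output[:] = 1.0

assert output.data_ptr() == attn_output.data_ptr()
assert torch.all(output == 1.0)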
