Skip to content

Commit c3215fc

Browse files
committed
[TRTLLM-6633][feat] Padding for piecewise cudagraph
Signed-off-by: Jin Li <[email protected]>
1 parent 2bb90ba commit c3215fc

File tree

9 files changed

+320
-215
lines changed

9 files changed

+320
-215
lines changed

cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -748,7 +748,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
748748
{
749749
m.def(
750750
"merge_chunked_attention_for_mla("
751-
"Tensor merged_attn"
751+
"Tensor(a!) merged_attn"
752752
", Tensor temp_attn"
753753
", Tensor merged_softmax_stats"
754754
", Tensor temp_softmax_stats"

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,9 @@ class AttentionMetadata:
135135
_num_ctx_tokens: int = field(init=False, default=0, repr=False)
136136
_num_tokens: int = field(init=False, default=0, repr=False)
137137

138+
# The number of tokens in the padded sequence.
139+
padded_num_tokens: Optional[int] = None
140+
138141
# This buffer is currently only used for TrtllmAttentionMetadata.
139142
cache_indirection: Optional[torch.Tensor] = None
140143

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,9 @@
1111

1212
from tensorrt_llm.llmapi.utils import enable_llm_debug
1313

14-
from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
15-
make_weak_ref)
14+
from ..utils import (get_model_extra_attrs,
15+
get_per_request_piecewise_cuda_graph_flag,
16+
get_piecewise_cuda_graph_flag, make_weak_ref)
1617
from .multi_stream.auto_multi_stream import multi_stream_schedule
1718
from .utils import get_capture_piecewise_cuda_graph_flag, is_call_function
1819

@@ -154,8 +155,10 @@ def __call__(self, *args):
154155
elif isinstance(self.compile_time_num_tokens, int):
155156
runtime_num_of_token = self.compile_time_num_tokens
156157

157-
if runtime_num_of_token is None or runtime_num_of_token not in self.entries or not get_piecewise_cuda_graph_flag(
158-
):
158+
if (runtime_num_of_token is None
159+
or runtime_num_of_token not in self.entries
160+
or not get_piecewise_cuda_graph_flag()
161+
or not get_per_request_piecewise_cuda_graph_flag()):
159162
return self.default_callable(*args)
160163

161164
entry = self.entries[runtime_num_of_token]

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -419,6 +419,9 @@ def forward(
419419
**kwargs,
420420
)
421421

422+
if attn_metadata.padded_num_tokens is not None:
423+
hidden_states = hidden_states[:attn_metadata.num_tokens]
424+
422425
if self.draft_model is not None:
423426
# get logits
424427
logits = self.logits_processor.forward(
@@ -427,9 +430,20 @@ def forward(
427430
attn_metadata,
428431
True,
429432
)
433+
mtp_input_ids = input_ids
434+
mtp_position_ids = position_ids
435+
if attn_metadata.padded_num_tokens is not None:
436+
if input_ids is not None:
437+
# Slice along the first dimension
438+
mtp_input_ids = input_ids[:attn_metadata.num_tokens]
439+
if position_ids is not None:
440+
# Slice along the last dimension
441+
mtp_position_ids = position_ids[:, :attn_metadata.
442+
num_tokens]
443+
430444
# get accepted tokens and next draft tokens
431-
return self.spec_worker(input_ids=input_ids,
432-
position_ids=position_ids,
445+
return self.spec_worker(input_ids=mtp_input_ids,
446+
position_ids=mtp_position_ids,
433447
hidden_states=hidden_states,
434448
logits=logits,
435449
attn_metadata=attn_metadata,

tensorrt_llm/_torch/modules/attention.py

Lines changed: 58 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -338,6 +338,19 @@ def _attn_impl(
338338
attention_sinks: Optional[torch.Tensor] = None,
339339
):
340340

341+
padded_num_tokens = attn_metadata.padded_num_tokens
342+
num_tokens = attn_metadata.num_tokens
343+
344+
if padded_num_tokens is not None:
345+
assert q.shape[0] == padded_num_tokens
346+
q = q[:num_tokens, :]
347+
if k is not None:
348+
assert k.shape[0] == padded_num_tokens
349+
k = k[:num_tokens, :]
350+
if v is not None:
351+
assert v.shape[0] == padded_num_tokens
352+
v = v[:num_tokens, :]
353+
341354
out_scale = None
342355
out_scale_sf = None
343356
has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
@@ -368,7 +381,7 @@ def _attn_impl(
368381
attention_window_size=attention_window_size,
369382
attention_mask_data=attention_mask_data,
370383
enable_attn_nvfp4_output=enable_attn_nvfp4_output,
371-
output=output,
384+
output=output[:num_tokens, :] if output is not None else None,
372385
output_sf=output_sf,
373386
attention_sinks=attention_sinks)
374387
if isinstance(attn_output, tuple):
@@ -937,11 +950,10 @@ def create_output(self, hidden_states: torch.Tensor):
937950
return hidden_states.new_empty([num_tokens, hidden_size],
938951
dtype=hidden_states.dtype)
939952

940-
def forward_impl(self,
941-
position_ids: Optional[torch.Tensor],
953+
def forward_impl(self, position_ids: Optional[torch.Tensor],
942954
hidden_states: torch.Tensor,
943955
attn_metadata: AttentionMetadata,
944-
output: Optional[torch.Tensor] = None) -> torch.Tensor:
956+
output: torch.Tensor) -> None:
945957
"""
946958
Forward pass for the MLA module.
947959
@@ -954,6 +966,18 @@ def forward_impl(self,
954966
Returns:
955967
None: The attention result is written in place into the caller-provided ``output`` tensor.
956968
"""
969+
# split q, k, v into context and gen batches
970+
num_contexts = attn_metadata.num_contexts
971+
num_generations = attn_metadata.num_generations
972+
num_ctx_tokens = attn_metadata.num_ctx_tokens
973+
num_tokens = attn_metadata.num_tokens
974+
padded_num_tokens = attn_metadata.padded_num_tokens
975+
976+
if padded_num_tokens is not None:
977+
hidden_states = hidden_states[:num_tokens, ...]
978+
if position_ids is not None:
979+
position_ids = position_ids[:num_tokens, ...]
980+
957981
if self.is_lite:
958982
compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
959983
[self.kv_lora_rank, self.qk_rope_head_dim], -1)
@@ -981,15 +1005,11 @@ def forward_impl(self,
9811005
self.aux_stream,
9821006
)
9831007

984-
# split q, k, v into context and gen batches
985-
num_contexts = attn_metadata.num_contexts
986-
num_generations = attn_metadata.num_generations
987-
num_ctx_tokens = attn_metadata.num_ctx_tokens
988-
num_tokens = attn_metadata.num_tokens
989-
9901008
assert q.shape[
9911009
0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"
9921010

1011+
assert output is not None, "output must be provided"
1012+
9931013
if num_contexts > 0:
9941014
q_ctx = q[:num_ctx_tokens, ...]
9951015
compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
@@ -999,17 +1019,14 @@ def forward_impl(self,
9991019
assert position_ids is not None
10001020
k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
10011021

1002-
attn_output_context = self.forward_context(
1022+
self.forward_context(
10031023
q_ctx,
10041024
compressed_kv_ctx,
10051025
k_pe_ctx,
10061026
attn_metadata,
1027+
output[:num_ctx_tokens, :],
10071028
latent_cache_ctx,
1008-
output=output if num_generations == 0 else None)
1009-
if num_generations == 0:
1010-
return attn_output_context
1011-
else:
1012-
attn_output_context = None
1029+
)
10131030

10141031
if num_generations > 0:
10151032
q_gen = q[num_ctx_tokens:, ...]
@@ -1020,39 +1037,14 @@ def forward_impl(self,
10201037
assert position_ids is not None
10211038
k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
10221039

1023-
attn_output_gen = self.forward_generation(
1040+
self.forward_generation(
10241041
q_gen,
10251042
compressed_kv_gen,
10261043
k_pe_gen,
10271044
attn_metadata,
1045+
output[num_ctx_tokens:num_tokens, :],
10281046
latent_cache_gen,
1029-
output=output if num_contexts == 0 else None)
1030-
if num_contexts == 0:
1031-
return attn_output_gen
1032-
else:
1033-
attn_output_gen = None
1034-
1035-
# release pytorch activation memory
1036-
q = None
1037-
compressed_kv = None
1038-
k_pe = None
1039-
1040-
assert attn_output_context is not None and attn_output_gen is not None
1041-
assert (
1042-
len(attn_output_context.shape) == 2
1043-
), f"attn_output_context must be rank 2, not {len(attn_output_context.shape)}"
1044-
assert (
1045-
len(attn_output_gen.shape) == 2
1046-
), f"attn_output_gen must be rank 2, not {len(attn_output_gen.shape)}"
1047-
output = output if output is not None else torch.empty(
1048-
(num_tokens, attn_output_context.shape[1]),
1049-
dtype=attn_output_context.dtype,
1050-
device=attn_output_context.device)
1051-
output[:attn_output_context.shape[0], :] = attn_output_context
1052-
output[attn_output_context.shape[0]:, :] = attn_output_gen
1053-
attn_output_context = None
1054-
attn_output_gen = None
1055-
return output
1047+
)
10561048

10571049
def _maybe_concat_qkv(self, q, k, v):
10581050
if k is not None and v is not None and self.support_fused_qkv:
@@ -1061,13 +1053,14 @@ def _maybe_concat_qkv(self, q, k, v):
10611053
return q, k, v
10621054

10631055
def forward_context_default(
1064-
self,
1065-
q: torch.Tensor,
1066-
compressed_kv: torch.Tensor,
1067-
k_pe: torch.Tensor,
1068-
attn_metadata: AttentionMetadata,
1069-
latent_cache: Optional[torch.Tensor] = None,
1070-
output: Optional[torch.Tensor] = None) -> torch.Tensor:
1056+
self,
1057+
q: torch.Tensor,
1058+
compressed_kv: torch.Tensor,
1059+
k_pe: torch.Tensor,
1060+
attn_metadata: AttentionMetadata,
1061+
output: torch.Tensor,
1062+
latent_cache: Optional[torch.Tensor] = None,
1063+
) -> torch.Tensor:
10711064
kv = self.kv_b_proj(compressed_kv)
10721065
k_nope, v = kv.split(
10731066
[
@@ -1109,7 +1102,7 @@ def forward_context_with_cached_kv(
11091102
q: torch.Tensor,
11101103
latent_cache: torch.Tensor,
11111104
attn_metadata: AttentionMetadata,
1112-
output: Optional[torch.Tensor] = None,
1105+
output: torch.Tensor,
11131106
) -> torch.Tensor:
11141107
assert latent_cache is not None
11151108
trtllm_attention = cast(TrtllmAttention, self.mha)
@@ -1195,7 +1188,7 @@ def forward_context_with_chunked_prefill(
11951188
latent_cache: torch.
11961189
Tensor, # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
11971190
attn_metadata: TrtllmAttentionMetadata,
1198-
output: Optional[torch.Tensor] = None,
1191+
output: torch.Tensor,
11991192
) -> torch.Tensor:
12001193
trtllm_attention = cast(TrtllmAttention, self.mha)
12011194
# apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
@@ -1218,11 +1211,8 @@ def forward_context_with_chunked_prefill(
12181211
dtype=torch.float,
12191212
device='cuda',
12201213
)
1221-
if output is None:
1222-
attn_output = q.new_empty(
1223-
(q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
1224-
else:
1225-
attn_output = output
1214+
1215+
attn_output = output
12261216
temp_attn_output = q.new_empty(
12271217
(q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
12281218

@@ -1354,8 +1344,8 @@ def forward_context(
13541344
compressed_kv: torch.Tensor,
13551345
k_pe: torch.Tensor,
13561346
attn_metadata: AttentionMetadata,
1347+
output: torch.Tensor,
13571348
latent_cache: Optional[torch.Tensor] = None,
1358-
output: Optional[torch.Tensor] = None,
13591349
) -> torch.Tensor:
13601350
if isinstance(self.mha, TrtllmAttention):
13611351
assert isinstance(attn_metadata, TrtllmAttentionMetadata)
@@ -1368,16 +1358,17 @@ def forward_context(
13681358
return self.forward_context_with_cached_kv(
13691359
q, latent_cache, attn_metadata, output)
13701360
return self.forward_context_default(q, compressed_kv, k_pe,
1371-
attn_metadata, latent_cache, output)
1361+
attn_metadata, output, latent_cache)
13721362

13731363
def forward_generation(
1374-
self,
1375-
q: torch.Tensor,
1376-
compressed_kv: torch.Tensor,
1377-
k_pe: torch.Tensor,
1378-
attn_metadata: AttentionMetadata,
1379-
latent_cache: Optional[torch.Tensor] = None,
1380-
output: Optional[torch.Tensor] = None) -> torch.Tensor:
1364+
self,
1365+
q: torch.Tensor,
1366+
compressed_kv: torch.Tensor,
1367+
k_pe: torch.Tensor,
1368+
attn_metadata: AttentionMetadata,
1369+
output: torch.Tensor,
1370+
latent_cache: Optional[torch.Tensor] = None,
1371+
) -> torch.Tensor:
13811372
num_tokens = q.shape[0]
13821373
q_nope, q_pe = q.view([-1, self.num_heads, self.qk_head_dim]).split(
13831374
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
@@ -1449,12 +1440,6 @@ def forward_generation(
14491440
attn_out_latent = attn_out_latent.view(
14501441
[-1, self.num_heads, self.kv_lora_rank])
14511442

1452-
# [seq, num_heads * v_head_dim]
1453-
output = output if output is not None else torch.empty(
1454-
[num_tokens, self.num_heads * self.v_head_dim],
1455-
dtype=attn_out_latent.dtype,
1456-
device=attn_out_latent.device)
1457-
14581443
attn_output = output.view([num_tokens, self.num_heads, self.v_head_dim])
14591444

14601445
if self.v_b_proj.dtype == torch.bfloat16:

0 commit comments

Comments
 (0)