@@ -338,6 +338,19 @@ def _attn_impl(
338
338
attention_sinks : Optional [torch .Tensor ] = None ,
339
339
):
340
340
341
+ padded_num_tokens = attn_metadata .padded_num_tokens
342
+ num_tokens = attn_metadata .num_tokens
343
+
344
+ if padded_num_tokens is not None :
345
+ assert q .shape [0 ] == padded_num_tokens
346
+ q = q [:num_tokens , :]
347
+ if k is not None :
348
+ assert k .shape [0 ] == padded_num_tokens
349
+ k = k [:num_tokens , :]
350
+ if v is not None :
351
+ assert v .shape [0 ] == padded_num_tokens
352
+ v = v [:num_tokens , :]
353
+
341
354
out_scale = None
342
355
out_scale_sf = None
343
356
has_quant_scale = (self .o_proj .has_fp8_qdq or self .o_proj .has_nvfp4
@@ -368,7 +381,7 @@ def _attn_impl(
368
381
attention_window_size = attention_window_size ,
369
382
attention_mask_data = attention_mask_data ,
370
383
enable_attn_nvfp4_output = enable_attn_nvfp4_output ,
371
- output = output ,
384
+ output = output [: num_tokens , :] if output is not None else None ,
372
385
output_sf = output_sf ,
373
386
attention_sinks = attention_sinks )
374
387
if isinstance (attn_output , tuple ):
@@ -936,11 +949,10 @@ def create_output(self, hidden_states: torch.Tensor):
936
949
return hidden_states .new_empty ([num_tokens , hidden_size ],
937
950
dtype = hidden_states .dtype )
938
951
939
- def forward_impl (self ,
940
- position_ids : Optional [torch .Tensor ],
952
+ def forward_impl (self , position_ids : Optional [torch .Tensor ],
941
953
hidden_states : torch .Tensor ,
942
954
attn_metadata : AttentionMetadata ,
943
- output : Optional [ torch .Tensor ] = None ) -> torch . Tensor :
955
+ output : torch .Tensor ) -> None :
944
956
"""
945
957
Forward pass for the MLA module.
946
958
@@ -953,6 +965,18 @@ def forward_impl(self,
953
965
Returns:
954
966
torch.Tensor: The output tensor.
955
967
"""
968
+ # split q, k, v into context and gen batches
969
+ num_contexts = attn_metadata .num_contexts
970
+ num_generations = attn_metadata .num_generations
971
+ num_ctx_tokens = attn_metadata .num_ctx_tokens
972
+ num_tokens = attn_metadata .num_tokens
973
+ padded_num_tokens = attn_metadata .padded_num_tokens
974
+
975
+ if padded_num_tokens is not None :
976
+ hidden_states = hidden_states [:num_tokens , ...]
977
+ if position_ids is not None :
978
+ position_ids = position_ids [:num_tokens , ...]
979
+
956
980
if self .is_lite :
957
981
compressed_kv , k_pe = self .kv_a_proj_with_mqa (hidden_states ).split (
958
982
[self .kv_lora_rank , self .qk_rope_head_dim ], - 1 )
@@ -980,15 +1004,11 @@ def forward_impl(self,
980
1004
self .aux_stream ,
981
1005
)
982
1006
983
- # split q, k, v into context and gen batches
984
- num_contexts = attn_metadata .num_contexts
985
- num_generations = attn_metadata .num_generations
986
- num_ctx_tokens = attn_metadata .num_ctx_tokens
987
- num_tokens = attn_metadata .num_tokens
988
-
989
1007
assert q .shape [
990
1008
0 ] == num_tokens , f"Expect q.shape[0] to be { num_tokens } , but got { q .shape [0 ]} "
991
1009
1010
+ assert output is not None , "output must be provided"
1011
+
992
1012
if num_contexts > 0 :
993
1013
q_ctx = q [:num_ctx_tokens , ...]
994
1014
compressed_kv_ctx = compressed_kv [:num_ctx_tokens , ...]
@@ -998,17 +1018,14 @@ def forward_impl(self,
998
1018
assert position_ids is not None
999
1019
k_pe_ctx = self .apply_rope (q_ctx , k_pe_ctx , position_ids )
1000
1020
1001
- attn_output_context = self .forward_context (
1021
+ self .forward_context (
1002
1022
q_ctx ,
1003
1023
compressed_kv_ctx ,
1004
1024
k_pe_ctx ,
1005
1025
attn_metadata ,
1026
+ output [:num_ctx_tokens , :],
1006
1027
latent_cache_ctx ,
1007
- output = output if num_generations == 0 else None )
1008
- if num_generations == 0 :
1009
- return attn_output_context
1010
- else :
1011
- attn_output_context = None
1028
+ )
1012
1029
1013
1030
if num_generations > 0 :
1014
1031
q_gen = q [num_ctx_tokens :, ...]
@@ -1019,48 +1036,24 @@ def forward_impl(self,
1019
1036
assert position_ids is not None
1020
1037
k_pe_gen = self .apply_rope (q_gen , k_pe_gen , position_ids )
1021
1038
1022
- attn_output_gen = self .forward_generation (
1039
+ self .forward_generation (
1023
1040
q_gen ,
1024
1041
compressed_kv_gen ,
1025
1042
k_pe_gen ,
1026
1043
attn_metadata ,
1044
+ output [num_ctx_tokens :num_tokens , :],
1027
1045
latent_cache_gen ,
1028
- output = output if num_contexts == 0 else None )
1029
- if num_contexts == 0 :
1030
- return attn_output_gen
1031
- else :
1032
- attn_output_gen = None
1033
-
1034
- # release pytorch activation memory
1035
- q = None
1036
- compressed_kv = None
1037
- k_pe = None
1038
-
1039
- assert attn_output_context is not None and attn_output_gen is not None
1040
- assert (
1041
- len (attn_output_context .shape ) == 2
1042
- ), f"attn_output_context must be rank 2, not { len (attn_output_context .shape )} "
1043
- assert (
1044
- len (attn_output_gen .shape ) == 2
1045
- ), f"attn_output_gen must be rank 2, not { len (attn_output_gen .shape )} "
1046
- output = output if output is not None else torch .empty (
1047
- (num_tokens , attn_output_context .shape [1 ]),
1048
- dtype = attn_output_context .dtype ,
1049
- device = attn_output_context .device )
1050
- output [:attn_output_context .shape [0 ], :] = attn_output_context
1051
- output [attn_output_context .shape [0 ]:, :] = attn_output_gen
1052
- attn_output_context = None
1053
- attn_output_gen = None
1054
- return output
1046
+ )
1055
1047
1056
1048
def forward_context_default (
1057
- self ,
1058
- q : torch .Tensor ,
1059
- compressed_kv : torch .Tensor ,
1060
- k_pe : torch .Tensor ,
1061
- attn_metadata : AttentionMetadata ,
1062
- latent_cache : Optional [torch .Tensor ] = None ,
1063
- output : Optional [torch .Tensor ] = None ) -> torch .Tensor :
1049
+ self ,
1050
+ q : torch .Tensor ,
1051
+ compressed_kv : torch .Tensor ,
1052
+ k_pe : torch .Tensor ,
1053
+ attn_metadata : AttentionMetadata ,
1054
+ output : torch .Tensor ,
1055
+ latent_cache : Optional [torch .Tensor ] = None ,
1056
+ ) -> torch .Tensor :
1064
1057
kv = self .kv_b_proj (compressed_kv )
1065
1058
k_nope , v = kv .split (
1066
1059
[
@@ -1099,7 +1092,7 @@ def forward_context_with_cached_kv(
1099
1092
q : torch .Tensor ,
1100
1093
latent_cache : torch .Tensor ,
1101
1094
attn_metadata : AttentionMetadata ,
1102
- output : Optional [ torch .Tensor ] = None ,
1095
+ output : torch .Tensor ,
1103
1096
) -> torch .Tensor :
1104
1097
assert latent_cache is not None
1105
1098
trtllm_attention = cast (TrtllmAttention , self .mha )
@@ -1168,7 +1161,7 @@ def forward_context_with_chunked_prefill(
1168
1161
latent_cache : torch .
1169
1162
Tensor , # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
1170
1163
attn_metadata : TrtllmAttentionMetadata ,
1171
- output : Optional [ torch .Tensor ] = None ,
1164
+ output : torch .Tensor ,
1172
1165
) -> torch .Tensor :
1173
1166
trtllm_attention = cast (TrtllmAttention , self .mha )
1174
1167
# apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
@@ -1190,11 +1183,8 @@ def forward_context_with_chunked_prefill(
1190
1183
dtype = torch .float ,
1191
1184
device = 'cuda' ,
1192
1185
)
1193
- if output is None :
1194
- attn_output = q .new_empty (
1195
- (q .size (0 ), self .num_heads * self .v_head_dim ), dtype = q .dtype )
1196
- else :
1197
- attn_output = output
1186
+
1187
+ attn_output = output
1198
1188
temp_attn_output = q .new_empty (
1199
1189
(q .size (0 ), self .num_heads * self .v_head_dim ), dtype = q .dtype )
1200
1190
@@ -1332,8 +1322,8 @@ def forward_context(
1332
1322
compressed_kv : torch .Tensor ,
1333
1323
k_pe : torch .Tensor ,
1334
1324
attn_metadata : AttentionMetadata ,
1325
+ output : torch .Tensor ,
1335
1326
latent_cache : Optional [torch .Tensor ] = None ,
1336
- output : Optional [torch .Tensor ] = None ,
1337
1327
) -> torch .Tensor :
1338
1328
if isinstance (self .mha , TrtllmAttention ):
1339
1329
assert isinstance (attn_metadata , TrtllmAttentionMetadata )
@@ -1346,16 +1336,17 @@ def forward_context(
1346
1336
return self .forward_context_with_cached_kv (
1347
1337
q , latent_cache , attn_metadata , output )
1348
1338
return self .forward_context_default (q , compressed_kv , k_pe ,
1349
- attn_metadata , latent_cache , output )
1339
+ attn_metadata , output , latent_cache )
1350
1340
1351
1341
def forward_generation (
1352
- self ,
1353
- q : torch .Tensor ,
1354
- compressed_kv : torch .Tensor ,
1355
- k_pe : torch .Tensor ,
1356
- attn_metadata : AttentionMetadata ,
1357
- latent_cache : Optional [torch .Tensor ] = None ,
1358
- output : Optional [torch .Tensor ] = None ) -> torch .Tensor :
1342
+ self ,
1343
+ q : torch .Tensor ,
1344
+ compressed_kv : torch .Tensor ,
1345
+ k_pe : torch .Tensor ,
1346
+ attn_metadata : AttentionMetadata ,
1347
+ output : torch .Tensor ,
1348
+ latent_cache : Optional [torch .Tensor ] = None ,
1349
+ ) -> torch .Tensor :
1359
1350
num_tokens = q .shape [0 ]
1360
1351
q_nope , q_pe = q .view ([- 1 , self .num_heads , self .qk_head_dim ]).split (
1361
1352
[self .qk_nope_head_dim , self .qk_rope_head_dim ], dim = - 1 )
@@ -1427,12 +1418,6 @@ def forward_generation(
1427
1418
attn_out_latent = attn_out_latent .view (
1428
1419
[- 1 , self .num_heads , self .kv_lora_rank ])
1429
1420
1430
- # [seq, num_heads * v_head_dim]
1431
- output = output if output is not None else torch .empty (
1432
- [num_tokens , self .num_heads * self .v_head_dim ],
1433
- dtype = attn_out_latent .dtype ,
1434
- device = attn_out_latent .device )
1435
-
1436
1421
attn_output = output .view ([num_tokens , self .num_heads , self .v_head_dim ])
1437
1422
1438
1423
if self .v_b_proj .dtype == torch .bfloat16 :
0 commit comments