@@ -338,6 +338,19 @@ def _attn_impl(
338
338
attention_sinks : Optional [torch .Tensor ] = None ,
339
339
):
340
340
341
+ padded_num_tokens = attn_metadata .padded_num_tokens
342
+ num_tokens = attn_metadata .num_tokens
343
+
344
+ if padded_num_tokens is not None :
345
+ assert q .shape [0 ] == padded_num_tokens
346
+ q = q [:num_tokens , :]
347
+ if k is not None :
348
+ assert k .shape [0 ] == padded_num_tokens
349
+ k = k [:num_tokens , :]
350
+ if v is not None :
351
+ assert v .shape [0 ] == padded_num_tokens
352
+ v = v [:num_tokens , :]
353
+
341
354
out_scale = None
342
355
out_scale_sf = None
343
356
has_quant_scale = (self .o_proj .has_fp8_qdq or self .o_proj .has_nvfp4
@@ -368,7 +381,7 @@ def _attn_impl(
368
381
attention_window_size = attention_window_size ,
369
382
attention_mask_data = attention_mask_data ,
370
383
enable_attn_nvfp4_output = enable_attn_nvfp4_output ,
371
- output = output ,
384
+ output = output [: num_tokens , :] if output is not None else None ,
372
385
output_sf = output_sf ,
373
386
attention_sinks = attention_sinks )
374
387
if isinstance (attn_output , tuple ):
@@ -937,11 +950,10 @@ def create_output(self, hidden_states: torch.Tensor):
937
950
return hidden_states .new_empty ([num_tokens , hidden_size ],
938
951
dtype = hidden_states .dtype )
939
952
940
- def forward_impl (self ,
941
- position_ids : Optional [torch .Tensor ],
953
+ def forward_impl (self , position_ids : Optional [torch .Tensor ],
942
954
hidden_states : torch .Tensor ,
943
955
attn_metadata : AttentionMetadata ,
944
- output : Optional [ torch .Tensor ] = None ) -> torch . Tensor :
956
+ output : torch .Tensor ) -> None :
945
957
"""
946
958
Forward pass for the MLA module.
947
959
@@ -954,6 +966,18 @@ def forward_impl(self,
954
966
Returns:
955
967
None: The attention result is written in place into the caller-provided `output` tensor.
956
968
"""
969
+ # split q, k, v into context and gen batches
970
+ num_contexts = attn_metadata .num_contexts
971
+ num_generations = attn_metadata .num_generations
972
+ num_ctx_tokens = attn_metadata .num_ctx_tokens
973
+ num_tokens = attn_metadata .num_tokens
974
+ padded_num_tokens = attn_metadata .padded_num_tokens
975
+
976
+ if padded_num_tokens is not None :
977
+ hidden_states = hidden_states [:num_tokens , ...]
978
+ if position_ids is not None :
979
+ position_ids = position_ids [:num_tokens , ...]
980
+
957
981
if self .is_lite :
958
982
compressed_kv , k_pe = self .kv_a_proj_with_mqa (hidden_states ).split (
959
983
[self .kv_lora_rank , self .qk_rope_head_dim ], - 1 )
@@ -981,15 +1005,11 @@ def forward_impl(self,
981
1005
self .aux_stream ,
982
1006
)
983
1007
984
- # split q, k, v into context and gen batches
985
- num_contexts = attn_metadata .num_contexts
986
- num_generations = attn_metadata .num_generations
987
- num_ctx_tokens = attn_metadata .num_ctx_tokens
988
- num_tokens = attn_metadata .num_tokens
989
-
990
1008
assert q .shape [
991
1009
0 ] == num_tokens , f"Expect q.shape[0] to be { num_tokens } , but got { q .shape [0 ]} "
992
1010
1011
+ assert output is not None , "output must be provided"
1012
+
993
1013
if num_contexts > 0 :
994
1014
q_ctx = q [:num_ctx_tokens , ...]
995
1015
compressed_kv_ctx = compressed_kv [:num_ctx_tokens , ...]
@@ -999,17 +1019,14 @@ def forward_impl(self,
999
1019
assert position_ids is not None
1000
1020
k_pe_ctx = self .apply_rope (q_ctx , k_pe_ctx , position_ids )
1001
1021
1002
- attn_output_context = self .forward_context (
1022
+ self .forward_context (
1003
1023
q_ctx ,
1004
1024
compressed_kv_ctx ,
1005
1025
k_pe_ctx ,
1006
1026
attn_metadata ,
1027
+ output [:num_ctx_tokens , :],
1007
1028
latent_cache_ctx ,
1008
- output = output if num_generations == 0 else None )
1009
- if num_generations == 0 :
1010
- return attn_output_context
1011
- else :
1012
- attn_output_context = None
1029
+ )
1013
1030
1014
1031
if num_generations > 0 :
1015
1032
q_gen = q [num_ctx_tokens :, ...]
@@ -1020,39 +1037,14 @@ def forward_impl(self,
1020
1037
assert position_ids is not None
1021
1038
k_pe_gen = self .apply_rope (q_gen , k_pe_gen , position_ids )
1022
1039
1023
- attn_output_gen = self .forward_generation (
1040
+ self .forward_generation (
1024
1041
q_gen ,
1025
1042
compressed_kv_gen ,
1026
1043
k_pe_gen ,
1027
1044
attn_metadata ,
1045
+ output [num_ctx_tokens :num_tokens , :],
1028
1046
latent_cache_gen ,
1029
- output = output if num_contexts == 0 else None )
1030
- if num_contexts == 0 :
1031
- return attn_output_gen
1032
- else :
1033
- attn_output_gen = None
1034
-
1035
- # release pytorch activation memory
1036
- q = None
1037
- compressed_kv = None
1038
- k_pe = None
1039
-
1040
- assert attn_output_context is not None and attn_output_gen is not None
1041
- assert (
1042
- len (attn_output_context .shape ) == 2
1043
- ), f"attn_output_context must be rank 2, not { len (attn_output_context .shape )} "
1044
- assert (
1045
- len (attn_output_gen .shape ) == 2
1046
- ), f"attn_output_gen must be rank 2, not { len (attn_output_gen .shape )} "
1047
- output = output if output is not None else torch .empty (
1048
- (num_tokens , attn_output_context .shape [1 ]),
1049
- dtype = attn_output_context .dtype ,
1050
- device = attn_output_context .device )
1051
- output [:attn_output_context .shape [0 ], :] = attn_output_context
1052
- output [attn_output_context .shape [0 ]:, :] = attn_output_gen
1053
- attn_output_context = None
1054
- attn_output_gen = None
1055
- return output
1047
+ )
1056
1048
1057
1049
def _maybe_concat_qkv (self , q , k , v ):
1058
1050
if k is not None and v is not None and self .support_fused_qkv :
@@ -1061,13 +1053,14 @@ def _maybe_concat_qkv(self, q, k, v):
1061
1053
return q , k , v
1062
1054
1063
1055
def forward_context_default (
1064
- self ,
1065
- q : torch .Tensor ,
1066
- compressed_kv : torch .Tensor ,
1067
- k_pe : torch .Tensor ,
1068
- attn_metadata : AttentionMetadata ,
1069
- latent_cache : Optional [torch .Tensor ] = None ,
1070
- output : Optional [torch .Tensor ] = None ) -> torch .Tensor :
1056
+ self ,
1057
+ q : torch .Tensor ,
1058
+ compressed_kv : torch .Tensor ,
1059
+ k_pe : torch .Tensor ,
1060
+ attn_metadata : AttentionMetadata ,
1061
+ output : torch .Tensor ,
1062
+ latent_cache : Optional [torch .Tensor ] = None ,
1063
+ ) -> torch .Tensor :
1071
1064
kv = self .kv_b_proj (compressed_kv )
1072
1065
k_nope , v = kv .split (
1073
1066
[
@@ -1109,7 +1102,7 @@ def forward_context_with_cached_kv(
1109
1102
q : torch .Tensor ,
1110
1103
latent_cache : torch .Tensor ,
1111
1104
attn_metadata : AttentionMetadata ,
1112
- output : Optional [ torch .Tensor ] = None ,
1105
+ output : torch .Tensor ,
1113
1106
) -> torch .Tensor :
1114
1107
assert latent_cache is not None
1115
1108
trtllm_attention = cast (TrtllmAttention , self .mha )
@@ -1195,7 +1188,7 @@ def forward_context_with_chunked_prefill(
1195
1188
latent_cache : torch .
1196
1189
Tensor , # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
1197
1190
attn_metadata : TrtllmAttentionMetadata ,
1198
- output : Optional [ torch .Tensor ] = None ,
1191
+ output : torch .Tensor ,
1199
1192
) -> torch .Tensor :
1200
1193
trtllm_attention = cast (TrtllmAttention , self .mha )
1201
1194
# apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
@@ -1218,11 +1211,8 @@ def forward_context_with_chunked_prefill(
1218
1211
dtype = torch .float ,
1219
1212
device = 'cuda' ,
1220
1213
)
1221
- if output is None :
1222
- attn_output = q .new_empty (
1223
- (q .size (0 ), self .num_heads * self .v_head_dim ), dtype = q .dtype )
1224
- else :
1225
- attn_output = output
1214
+
1215
+ attn_output = output
1226
1216
temp_attn_output = q .new_empty (
1227
1217
(q .size (0 ), self .num_heads * self .v_head_dim ), dtype = q .dtype )
1228
1218
@@ -1354,8 +1344,8 @@ def forward_context(
1354
1344
compressed_kv : torch .Tensor ,
1355
1345
k_pe : torch .Tensor ,
1356
1346
attn_metadata : AttentionMetadata ,
1347
+ output : torch .Tensor ,
1357
1348
latent_cache : Optional [torch .Tensor ] = None ,
1358
- output : Optional [torch .Tensor ] = None ,
1359
1349
) -> torch .Tensor :
1360
1350
if isinstance (self .mha , TrtllmAttention ):
1361
1351
assert isinstance (attn_metadata , TrtllmAttentionMetadata )
@@ -1368,16 +1358,17 @@ def forward_context(
1368
1358
return self .forward_context_with_cached_kv (
1369
1359
q , latent_cache , attn_metadata , output )
1370
1360
return self .forward_context_default (q , compressed_kv , k_pe ,
1371
- attn_metadata , latent_cache , output )
1361
+ attn_metadata , output , latent_cache )
1372
1362
1373
1363
def forward_generation (
1374
- self ,
1375
- q : torch .Tensor ,
1376
- compressed_kv : torch .Tensor ,
1377
- k_pe : torch .Tensor ,
1378
- attn_metadata : AttentionMetadata ,
1379
- latent_cache : Optional [torch .Tensor ] = None ,
1380
- output : Optional [torch .Tensor ] = None ) -> torch .Tensor :
1364
+ self ,
1365
+ q : torch .Tensor ,
1366
+ compressed_kv : torch .Tensor ,
1367
+ k_pe : torch .Tensor ,
1368
+ attn_metadata : AttentionMetadata ,
1369
+ output : torch .Tensor ,
1370
+ latent_cache : Optional [torch .Tensor ] = None ,
1371
+ ) -> torch .Tensor :
1381
1372
num_tokens = q .shape [0 ]
1382
1373
q_nope , q_pe = q .view ([- 1 , self .num_heads , self .qk_head_dim ]).split (
1383
1374
[self .qk_nope_head_dim , self .qk_rope_head_dim ], dim = - 1 )
@@ -1449,12 +1440,6 @@ def forward_generation(
1449
1440
attn_out_latent = attn_out_latent .view (
1450
1441
[- 1 , self .num_heads , self .kv_lora_rank ])
1451
1442
1452
- # [seq, num_heads * v_head_dim]
1453
- output = output if output is not None else torch .empty (
1454
- [num_tokens , self .num_heads * self .v_head_dim ],
1455
- dtype = attn_out_latent .dtype ,
1456
- device = attn_out_latent .device )
1457
-
1458
1443
attn_output = output .view ([num_tokens , self .num_heads , self .v_head_dim ])
1459
1444
1460
1445
if self .v_b_proj .dtype == torch .bfloat16 :
0 commit comments