 import triton
 import triton.language as tl
 import triton.language.extra.tlx as tlx
+from triton.tools.tensor_descriptor import TensorDescriptor
 
 from .gdpa_utils import get_num_sms
 from .math import activation_string_to_int
 
 
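+# Autotuner pre-hook: when the kernel is launched with host-side
+# TensorDescriptor arguments, each autotune config must patch the
+# descriptors' block shapes to match its tile sizes before launch.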
+def _host_descriptor_pre_hook(nargs):
+    BLOCK_M = nargs["BLOCK_M"]
+    BLOCK_N = nargs["BLOCK_N"]
+    BLOCK_D = nargs["BLOCK_D"]
+    if not isinstance(nargs["Q"], TensorDescriptor):
+        # early return for on-device TMA
+        return
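+    # Q and Out tiles are split across the two MMA warp groups, so their M
+    # dimension is halved; K and V keep the full [BLOCK_N, BLOCK_D] tile.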
+    NUM_MMA_GROUPS = 2
+    BLOCK_M_SPLIT = BLOCK_M // NUM_MMA_GROUPS
+    nargs["Q"].block_shape = [BLOCK_M_SPLIT, BLOCK_D]
+    nargs["V"].block_shape = [BLOCK_N, BLOCK_D]
+    nargs["K"].block_shape = [BLOCK_N, BLOCK_D]
+    nargs["Out"].block_shape = [BLOCK_M_SPLIT, BLOCK_D]
+
+
 def get_cuda_autotune_config():
     return [
         triton.Config(
@@ -24,6 +40,7 @@ def get_cuda_autotune_config():
             },
             num_warps=4,
             num_stages=1,
+            pre_hook=_host_descriptor_pre_hook,
         )
         for BM in [256]  # 128 or 256
         for BN in [128]
@@ -198,6 +215,7 @@ def gdpa_kernel_tma_ws_blackwell(
     BROADCAST_Q: tl.constexpr,
     IS_DENSE_KV: tl.constexpr,
     activation_enum_int: tl.constexpr,
+    USE_ON_DEVICE_TMA: tl.constexpr,
     NUM_BUFFERS_Q: tl.constexpr,
     NUM_BUFFERS_KV: tl.constexpr,
     NUM_BUFFERS_QK: tl.constexpr,
@@ -214,21 +232,27 @@ def gdpa_kernel_tma_ws_blackwell(
         tiles_per_sm += 1
 
     tile_idx = prog_id
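+    # Host-side TMA: descriptors arrive prebuilt as kernel arguments, so alias
+    # them here and keep the kernel body descriptor-source agnostic.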
+    if not USE_ON_DEVICE_TMA:
+        q_desc = Q
+        k_desc = K
+        v_desc = V
+        o_desc = Out
 
     # start with on-device TMA where descriptors for k, v are set up outside of the persistent
     # loop and descriptor for q is set up inside the persistent loop.
-    k_desc = tl.make_tensor_descriptor(
-        K,
-        shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
-        strides=[HEAD_DIM * H // G, 1],
-        block_shape=[BLOCK_N, BLOCK_D],
-    )
-    v_desc = tl.make_tensor_descriptor(
-        V,
-        shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
-        strides=[HEAD_DIM * H // G, 1],
-        block_shape=[BLOCK_N, BLOCK_D],
-    )
+    if USE_ON_DEVICE_TMA:
+        k_desc = tl.make_tensor_descriptor(
+            K,
+            shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
+            strides=[HEAD_DIM * H // G, 1],
+            block_shape=[BLOCK_N, BLOCK_D],
+        )
+        v_desc = tl.make_tensor_descriptor(
+            V,
+            shape=[N_CTX_KV * Z, HEAD_DIM * H // G],
+            strides=[HEAD_DIM * H // G, 1],
+            block_shape=[BLOCK_N, BLOCK_D],
+        )
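+        # On-device descriptors are materialized in global memory supplied by
+        # the allocator registered on the host (see alloc_fn below).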
 
     # allocate buffers for q0, q1
     q0_buf = tlx.local_alloc((BLOCK_M // 2, BLOCK_D), tl.float16, 1)
@@ -326,20 +350,12 @@ def gdpa_kernel_tma_ws_blackwell(
                     qk0 = tlx.local_load(qk_view)  # , tlx.storage_kind.tmem)
                     # ConsumerWait for qk, ProducerAcquire for p
                     # if activation_enum_int == 3:
-                    p0 = (
-                        qk0
-                        * 0.5
-                        * (
-                            1
-                            + tanh_approx_fp32(
-                                0.7978845608 * qk0 * (1.0 + 0.044715 * qk0 * qk0)
-                            )
-                        )
-                    )  # fast_gelu(qk0)
-                    # else:
-                    #     p0 = qk0
+                    p0 = fast_gelu(qk0)
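+                    # tanh-approximated GELU:
+                    # 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x * x)))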
                     p0 *= qk_scale
-                    p0 = p0.to(V.dtype.element_ty)  # v_dtype)
+                    if USE_ON_DEVICE_TMA:
+                        p0 = p0.to(V.dtype.element_ty)  # v_dtype)
+                    else:
+                        p0 = p0.to(tlx.dtype_of(v_desc))
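+                    # With host TMA, V is a descriptor rather than a raw
+                    # pointer, so the element type is read off the descriptor.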
                     qk_view = tlx.local_view(qk0_buf, bufIdx)
                     p0_view = tlx.local_reinterpret(qk_view, tl.float16)
                     tlx.local_store(p0_view, p0)  # , tlx.storage_kind.tmem)
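+                    # p0 reuses the qk0 buffer in place: local_reinterpret
+                    # retypes the view to fp16 without copying.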
@@ -371,18 +387,23 @@ def gdpa_kernel_tma_ws_blackwell(
                 )
                 # tl.device_print("default producer_o0", accum_cnt_outer)
                 tlx.barrier_arrive(consumer_release_o0_view, 1)
-                o0_desc = tl.make_tensor_descriptor(
-                    Out,
-                    shape=[end_q.to(tl.int32), HEAD_DIM * H],
-                    strides=[HEAD_DIM * H, 1],
-                    block_shape=[BLOCK_M // 2, BLOCK_D],
-                )
-                o0_desc.store(
+                if USE_ON_DEVICE_TMA:
+                    o_desc = tl.make_tensor_descriptor(
+                        Out,
+                        shape=[end_q.to(tl.int32), HEAD_DIM * H],
+                        strides=[HEAD_DIM * H, 1],
+                        block_shape=[BLOCK_M // 2, BLOCK_D],
+                    )
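+                    # Rebuilt per tile: the descriptor shape depends on the
+                    # runtime row bound end_q for this batch entry.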
+                if USE_ON_DEVICE_TMA:
+                    o0 = o0.to(Out.type.element_ty)
+                else:
+                    o0 = o0.to(tlx.dtype_of(o_desc))
+                o_desc.store(
                     [
                         (begin_q + start_m * BLOCK_M).to(tl.int32),
                         (out_offset).to(tl.int32),
                     ],
-                    o0.to(Out.type.element_ty),
+                    o0,
                 )
                 accum_cnt_outer += 1
                 tile_idx += num_progs
@@ -420,20 +441,12 @@ def gdpa_kernel_tma_ws_blackwell(
                     qk1 = tlx.local_load(qk_view)  # , tlx.storage_kind.tmem)
                     # ConsumerWait for qk, ProducerAcquire for p
                     # if activation_enum_int == 3:
-                    p1 = (
-                        qk1
-                        * 0.5
-                        * (
-                            1
-                            + tanh_approx_fp32(
-                                0.7978845608 * qk1 * (1.0 + 0.044715 * qk1 * qk1)
-                            )
-                        )
-                    )  # fast_gelu(qk1)
-                    # else:
-                    #     p1 = qk1
+                    p1 = fast_gelu(qk1)
                     p1 *= qk_scale
-                    p1 = p1.to(V.dtype.element_ty)  # v_dtype)
+                    if USE_ON_DEVICE_TMA:
+                        p1 = p1.to(V.dtype.element_ty)  # v_dtype)
+                    else:
+                        p1 = p1.to(tlx.dtype_of(v_desc))
                     qk_view = tlx.local_view(qk1_buf, bufIdx)
                     p1_view = tlx.local_reinterpret(qk_view, tl.float16)
                     tlx.local_store(p1_view, p1)  # , tlx.storage_kind.tmem)
@@ -452,12 +465,13 @@ def gdpa_kernel_tma_ws_blackwell(
                 bufIdx_o_outer, phase_o_outer = _get_bufidx_phase(
                     accum_cnt_outer, NUM_BUFFERS_O
                 )
-                o0_desc = tl.make_tensor_descriptor(
-                    Out,
-                    shape=[end_q.to(tl.int32), HEAD_DIM * H],
-                    strides=[HEAD_DIM * H, 1],
-                    block_shape=[BLOCK_M // 2, BLOCK_D],
-                )
+                if USE_ON_DEVICE_TMA:
+                    o_desc = tl.make_tensor_descriptor(
+                        Out,
+                        shape=[end_q.to(tl.int32), HEAD_DIM * H],
+                        strides=[HEAD_DIM * H, 1],
+                        block_shape=[BLOCK_M // 2, BLOCK_D],
+                    )
                 o1_view = tlx.local_view(
                     o1_buf, bufIdx_o_outer
                 )  # FIXME: should be 0
@@ -467,12 +481,16 @@ def gdpa_kernel_tma_ws_blackwell(
                     producer_o1, bufIdx_o_outer
                 )
                 tlx.barrier_arrive(consumer_release_o1_view, 1)
-                o0_desc.store(
+                if USE_ON_DEVICE_TMA:
+                    o1 = o1.to(Out.type.element_ty)
+                else:
+                    o1 = o1.to(tlx.dtype_of(o_desc))
+                o_desc.store(
                     [
                         (begin_q + start_m * BLOCK_M + BLOCK_M // 2).to(tl.int32),
                         (out_offset).to(tl.int32),
                     ],
-                    o1.to(Out.type.element_ty),
+                    o1,
                 )
                 accum_cnt_outer += 1
                 tile_idx += num_progs
@@ -581,6 +599,7 @@ def gdpa_kernel_tma_ws_blackwell(
             producer_o1_view = tlx.local_view(producer_o1, bufIdx_o_outer)
             # tl.device_print("gemm producer_o0", accum_cnt_outer)
             # tl.device_print("gemm producer_o0_phase", phase_o_outer)
+            # DEBUG_PERF
             tlx.barrier_wait(
                 producer_o0_view, phase_o_outer ^ 1
             )  # producer acquire for o0
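+            # Producer acquire waits on the inverted phase: barrier parity
+            # flips when the consumer releases, so phase ^ 1 means buffer free.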
@@ -591,6 +610,7 @@ def gdpa_kernel_tma_ws_blackwell(
             consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_p)
             # tl.device_print("gemm producer_qk0", accum_cnt_qk)
             # tl.device_print("gemm producer_qk0_phase", phase_p)
+            # DEBUG_PERF_P
             tlx.barrier_wait(
                 consumer_p0_view, phase_p
             )  # consumer wait for p0 due to reuse of p0 and qk0
@@ -660,11 +680,13 @@ def gdpa_kernel_tma_ws_blackwell(
             consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
             # tl.device_print("gemm producer_o1", accum_cnt_outer)
             # tl.device_print("gemm producer_o1_phase", phase_o_outer)
+            # DEBUG_PERF
             tlx.barrier_wait(
                 producer_o1_view, phase_o_outer ^ 1, first
             )  # producer acquire for o1, only needed for first iteration
             # tl.device_print("gemm producer_qk1", accum_cnt_qk1)
             # tl.device_print("gemm producer_qk1_phase", phase_qk1)
+            # DEBUG_PERF_P
             tlx.barrier_wait(
                 consumer_p1_view, phase_qk1
             )  # consumer wait for p1 use producer_qk1 due to reuse
@@ -741,6 +763,7 @@ def gdpa_kernel_tma_ws_blackwell(
             consumer_p0_view = tlx.local_view(producer_qk0, bufIdx_qk)
             # tl.device_print("gemm producer_qk0", accum_cnt_qk)
             # tl.device_print("gemm producer_qk0_phase", phase_qk)
+            # DEBUG_PERF_P
             tlx.barrier_wait(
                 consumer_p0_view, phase_qk
             )  # consumer wait for p0 use producer_qk0 due to reuse
@@ -780,6 +803,7 @@ def gdpa_kernel_tma_ws_blackwell(
             tlx.tcgen05_commit(release_q1_view)
             # tl.device_print("gemm producer_o1_epilogue", accum_cnt_outer)
             # tl.device_print("gemm producer_o1_phase", phase_o_outer)
+            # DEBUG_PERF
             tlx.barrier_wait(
                 producer_o1_view, phase_o_outer ^ 1, first
             )  # producer acquire for o1 at the first iteration
@@ -789,6 +813,7 @@ def gdpa_kernel_tma_ws_blackwell(
             consumer_p1_view = tlx.local_view(producer_qk1, bufIdx_qk1)
             # tl.device_print("gemm producer_qk1_epilogue", accum_cnt_qk1)
             # tl.device_print("gemm producer_qk1_phase", phase_qk1)
+            # DEBUG_PERF_P
             tlx.barrier_wait(
                 consumer_p1_view, phase_qk1
             )  # consumer wait for p1 due to reuse of p1 and qk1
@@ -862,12 +887,13 @@ def gdpa_kernel_tma_ws_blackwell(
         if start_m * BLOCK_M < qlen:
             # begin_o = tl.load(Out_offsets + off_z)  # confirm if tma store should use begin_q
 
-            q_desc = tl.make_tensor_descriptor(
-                Q,
-                shape=[end_q.to(tl.int32), HEAD_DIM * H],
-                strides=[HEAD_DIM * H, 1],
-                block_shape=[BLOCK_M // 2, BLOCK_D],
-            )
+            if USE_ON_DEVICE_TMA:
+                q_desc = tl.make_tensor_descriptor(
+                    Q,
+                    shape=[end_q.to(tl.int32), HEAD_DIM * H],
+                    strides=[HEAD_DIM * H, 1],
+                    block_shape=[BLOCK_M // 2, BLOCK_D],
+                )
 
             # calculate bufIdx and phase from accum_count_q
             q_bufIdx = accum_count_q % NUM_BUFFERS_Q
@@ -1131,6 +1157,40 @@ def gdpa_forward_tlx(
         print("NUM_SMS", NUM_SMS)
         print(triton.cdiv(max_seq_len_q, 256) * BATCH * nheads)
 
+    q = expect_contiguous(query)
+    k = expect_contiguous(key)
+    v = expect_contiguous(value)
+    kstrides = k.stride()
+    vstrides = v.stride()
+
+    dummy_block = [1, 1]
+    N_CTX_KV = max_seq_len_kv
+    HEAD_DIM = HEAD_DIM_K
+    Z = BATCH
+    H = nheads
+    y_dim = N_CTX_KV * Z
+    x_dim = HEAD_DIM * H // G
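+    # Compile-time switch: on-device TMA builds descriptors inside the kernel
+    # via tl.make_tensor_descriptor; otherwise host-built TensorDescriptors
+    # are passed in as kernel arguments.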
+    USE_ON_DEVICE_TMA = True
+    if not USE_ON_DEVICE_TMA:
+        desc_q = TensorDescriptor(
+            q,
+            shape=[y_dim, HEAD_DIM * H],
+            strides=[HEAD_DIM * H, 1],
+            block_shape=dummy_block,
+        )
+        desc_v = TensorDescriptor(
+            v, shape=[y_dim, x_dim], strides=[x_dim, 1], block_shape=dummy_block
+        )
+        desc_k = TensorDescriptor(
+            k, shape=[y_dim, x_dim], strides=[x_dim, 1], block_shape=dummy_block
+        )
+        desc_o = TensorDescriptor(
+            o,
+            shape=[y_dim, HEAD_DIM * H],
+            strides=[HEAD_DIM * H, 1],
+            block_shape=dummy_block,
+        )
1193+
11341194 # TMA descriptors require a global memory allocation
11351195 def alloc_fn (size : int , alignment : int , _ ):
11361196 return torch .empty (size , device = "cuda" , dtype = torch .int8 )
@@ -1144,22 +1204,19 @@ def grid_tma_persistent(META):
11441204 1 ,
11451205 )
11461206
-    q = expect_contiguous(query)
-    k = expect_contiguous(key)
-    v = expect_contiguous(value)
-    kstrides = k.stride()
-    vstrides = v.stride()
-
     activation_enum_int = activation_string_to_int(activation)
+    print(q.shape, k.shape, v.shape)
     # print("activation_enum_int", activation, activation_enum_int)
+    # print(query_offset)
+    # print(key_offset)
 
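+    # The same kernel parameters accept either raw tensors (on-device TMA) or
+    # host-built descriptors; the pre-hook patches only the descriptor case.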
     gdpa_kernel_tma_ws_blackwell[grid_tma_persistent](
-        q,
+        q if USE_ON_DEVICE_TMA else desc_q,
         query_offset,
-        k,
+        k if USE_ON_DEVICE_TMA else desc_k,
         key_offset,
-        v,
-        o,  #
+        v if USE_ON_DEVICE_TMA else desc_v,
+        o if USE_ON_DEVICE_TMA else desc_o,
         output_offset,
         ad_to_request_offset,
         seq_index,
@@ -1194,6 +1251,7 @@ def grid_tma_persistent(META):
         BROADCAST_Q=broadcast_q,
         IS_DENSE_KV=is_dense_kv,
         activation_enum_int=activation_enum_int,
+        USE_ON_DEVICE_TMA=USE_ON_DEVICE_TMA,
         **extra_kern_args,
     )
     return o