namespace tensorrt_llm::kernels::mnnvl
{
+
+ // Guard for internal helper functions
+ namespace
+ {
__device__ bool isNegZero(float v)
{
    return v == 0.f && signbit(v);
@@ -49,6 +53,12 @@ inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val)
    return __bfloat162float(val);
}

+ template <>
+ inline __device__ float toFloat<__nv_half>(__nv_half val)
+ {
+     return __half2float(val);
+ }
+
template <typename T>
inline __device__ T fromFloat(float val)
{
@@ -61,30 +71,76 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
    return __float2bfloat16(val);
}

- __device__ float4 loadfloat4(void const* ptr)
+ template <>
+ inline __device__ __nv_half fromFloat<__nv_half>(float val)
{
+     return __float2half(val);
+ }

-     float return_value[4];
-
-     asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
-         : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
-         : "l"(ptr));
-
-     return *(float4*) return_value;
+ inline __device__ float2 loadfloat2(void const* ptr)
+ {
+     float2 return_value;
+     asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(return_value.x), "=f"(return_value.y) : "l"(ptr));
+     return return_value;
}
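
// Note: the ld.volatile vector load forces each call to re-read the buffer from global
// memory rather than reusing a cached register value, which the Lamport-style readiness
// polling in the kernels below depends on.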

- __device__ __inline__ float2 loadfloat2(void const* ptr)
+ template <typename T>
+ inline __device__ T divUp(T val, T divisor)
{
+     return (val + divisor - 1) / divisor;
+ }

-     float return_value[2];
+ __device__ struct __attribute__((aligned(32))) LamportFlags
+ {
+     uint32_t buffer_size;
+     uint32_t input_offset;
+     uint32_t clear_offset;
+     uint32_t num_tokens_prev;
+     uint32_t* offset_access_ptr;
+     uint32_t* buffer_flags;
+
+     __device__ explicit LamportFlags(uint32_t* buffer_flags)
+         : offset_access_ptr(&buffer_flags[4])
+         , buffer_flags(buffer_flags)
+     {
+         uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+         buffer_size = flag.z;
+         input_offset = flag.x * (buffer_size << 1U);
+         clear_offset = flag.y * (buffer_size << 1U);
+         num_tokens_prev = flag.w;
+     }

-     asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
-         : "=f"(return_value[0]), "=f"(return_value[1])
-         : "l"(ptr)
-         : "memory");
+     __device__ void cta_arrive()
+     {
+         __syncthreads();
+         if (threadIdx.x == 0)
+         {
+ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+             asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+ #elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+             asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+ #else
+             atomicAdd(offset_access_ptr, 1);
+ #endif
+         }
+     }
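+
+     // Note: the PTX red.* variants above are fire-and-forget reductions (no return value),
+     // selected by __CUDA_ARCH__ at compile time; atomicAdd is the portable fallback and is
+     // equivalent here because the previous counter value is never used.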

-     return *(float2*) return_value;
- }
+     __device__ void wait_and_update(uint32_t num_tokens)
+     {
+         if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
+         {
+             while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
+             {
+             }
+             uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+             buffer_flags[0] = (flag.x + 1) % 3;
+             buffer_flags[1] = (flag.y + 1) % 3;
+             buffer_flags[3] = num_tokens;
+             *(offset_access_ptr) = 0;
+         }
+     }
+ };
+ } // namespace
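+
+ // Illustrative sketch of the buffer_flags layout assumed by LamportFlags (example values
+ // are hypothetical):
+ //   buffer_flags[0] = index of the buffer group used for input            (flag.x)
+ //   buffer_flags[1] = index of the buffer group being cleared             (flag.y)
+ //   buffer_flags[2] = buffer_size, i.e. M * N elements per buffer         (flag.z)
+ //   buffer_flags[3] = number of tokens processed by the previous call     (flag.w)
+ //   buffer_flags[4] = grid-wide arrival counter driven by cta_arrive()/wait_and_update()
+ // Each group holds two buffers (reduce-scatter and allgather), hence the (buffer_size << 1U)
+ // stride. For example, with flag = {1, 0, 4096, 32}: input_offset = 1 * 8192, clear_offset = 0,
+ // and wait_and_update() advances the two indices to {2, 1}, rotating through three groups.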

template <int WORLD_SIZE, typename T>
__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
@@ -99,13 +155,14 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
    cudaGridDependencySynchronize();
#endif

-     // [input_ptr, clear_ptr, buffer_size, access_counter]
-     uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-     // Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
-     uint32_t buffer_group_size = flag.z << 1;
-     uint32_t input_offset = flag.x * buffer_group_size;
-     uint32_t clear_offset = flag.y * buffer_group_size;
-     uint32_t* offset_access_ptr = &buffer_flags[3];
+     LamportFlags flags(buffer_flags);
+
+     // Capture the number of tokens in previous iteration so that we can properly clear the buffer
+     // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up
+     uint32_t clr_toks_cta
+         = divUp<uint32_t>(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, WORLD_SIZE)
+         * WORLD_SIZE;
+     clr_toks_cta = divUp<uint32_t>(clr_toks_cta, gridDim.x);
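+     // Worked example (hypothetical values): with num_tokens_prev = 130, num_tokens = 96,
+     // WORLD_SIZE = 8 and gridDim.x = 64, the larger count 130 rounds up to 136
+     // (divUp(130, 8) * 8 = 136), and divUp(136, 64) = 3 tokens are cleared per CTA.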

    if (elt < token_dim)
    {
@@ -115,29 +172,33 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
        T val = shard_ptr[token * token_dim + elt];
        if (isNegZero(val))
            val = fromFloat<T>(0.f);
-         input_ptrs[dest_rank][input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt] = val;
+         input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt]
+             = val;

-         // Reduce and broadcast
+         // Clear the buffer used by the previous call. Note the number of tokens to clear could be larger than the
+         // number of tokens in the current call.
+         for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
+         {
+             uint32_t clr_token_idx = token + clr_tok * gridDim.x;
+             if (clr_token_idx < buffer_M)
+             {
+                 input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat<T>(-0.f);
+             }
+         }

+         // Reduce and broadcast
        if ((token % WORLD_SIZE) == rank)
        {
            int local_token = token / WORLD_SIZE;
            float accum = 0.f;

            T values[WORLD_SIZE];
-
-             for (int r = 0; r < WORLD_SIZE; r++)
-             {
-                 input_ptrs[rank][clear_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]
-                     = fromFloat<T>(-0.f);
-             }
-
            while (1)
            {
                bool valid = true;
                for (int r = 0; r < WORLD_SIZE; r++)
                {
-                     T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
+                     T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][flags.input_offset
                        + local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
                    values[r] = *lamport_ptr;
                    valid &= !isNegZero(values[r]);
@@ -149,40 +210,39 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
            {
                accum += toFloat<T>(values[r]);
            }
-             mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
+             mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
        }
    }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion();
#endif

-     input_ptrs[rank][clear_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(-0.f);
+     // Similarly clear broadcast buffer here
+     for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
+     {
+         uint32_t clr_token_idx = token + clr_tok * gridDim.x;
+         if (clr_token_idx < buffer_M)
+         {
+             input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
+                 = fromFloat<T>(-0.f);
+         }
+     }

    // Optionally wait for results if the next layer isn't doing the Lamport check
    if (wait_for_results)
    {
        // Update the atomic counter to indicate the block has read the offsets
-         __syncthreads();
+         flags.cta_arrive();

-         if (threadIdx.x == 0)
-         {
- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-             asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
- #elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-             asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
- #else
-             atomicAdd(offset_access_ptr, 1);
- #endif
-         }
        // Only use a set of CTAs for Lamport sync, rearrange the grid
        constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
        // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
        if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
        {
            uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;

-             void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
+             void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
            // We have 2 assumptions here:
            // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
            // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
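            // For example, ELTS_PER_LOAD = sizeof(float2) / sizeof(T) = 8 / 2 = 4 for BF16
            // (and likewise 4 for the FP16 path added in this change), while 8 / 4 = 2 for FP32.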
@@ -198,16 +258,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
        }

        // Update the buffer flags
-         if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
-         {
-             // Make sure all blocks have finished reading the offsets, 2-D grid
-             while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-             {
-             }
-             buffer_flags[0] = (flag.x + 1) % 3;
-             buffer_flags[1] = (flag.y + 1) % 3;
-             *(offset_access_ptr) = 0;
-         }
+         flags.wait_and_update(num_tokens);
    }
}

@@ -273,12 +324,28 @@ void twoshot_allreduce_op(AllReduceParams const& params)
        default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size.");
        }
    }
+     else if (dtype == nvinfer1::DataType::kHALF)
+     {
+         switch (world_size)
+         {
+         case 2: LAUNCH_ALL_REDUCE_KERNEL(2, __nv_half); break;
+         case 4: LAUNCH_ALL_REDUCE_KERNEL(4, __nv_half); break;
+         case 8: LAUNCH_ALL_REDUCE_KERNEL(8, __nv_half); break;
+         case 16: LAUNCH_ALL_REDUCE_KERNEL(16, __nv_half); break;
+         case 32: LAUNCH_ALL_REDUCE_KERNEL(32, __nv_half); break;
+         case 64: LAUNCH_ALL_REDUCE_KERNEL(64, __nv_half); break;
+         default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size.");
+         }
+     }
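+     // Note: the FP16 dispatch mirrors the BF16 path above; it relies on the __nv_half
+     // toFloat/fromFloat specializations added earlier, while the kernel template is unchanged.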
    else
    {
        TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported dtype.");
    }
}

+ // Guard for internal helper functions
+ namespace
+ {
template <typename T_IN>
__device__ void copy_f4(T_IN* dst, T_IN const* src)
{
@@ -327,6 +394,19 @@ inline __device__ float block_reduce_sum(float val)
    return val;
}

+ __device__ float4 loadfloat4(void const* ptr)
+ {
+
+     float4 return_value;
+
+     asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
+         : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), "=f"(return_value.w)
+         : "l"(ptr));
+
+     return return_value;
+ }
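+ // Note: float4 counterpart of loadfloat2 above, now in the anonymous namespace and
+ // returning the vector directly instead of reinterpreting a local array; the volatile
+ // 16B load ensures the value is re-read from global memory on every call.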
+ } // namespace
+
template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>
__global__ void __launch_bounds__(128, 1)
    RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon,
@@ -353,12 +433,8 @@ __global__ void __launch_bounds__(128, 1)

    int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];

-     uint32_t* offset_access_ptr = &buffer_flags[3];
-     uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-     // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-     uint32_t buffer_size = flag.z;
-     uint32_t buffer_offset = flag.x * (buffer_size << 1);
-     T_IN const* input = &buffer_input[buffer_offset + buffer_size];
+     LamportFlags flags(buffer_flags);
+     T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size];

    cudaTriggerProgrammaticLaunchCompletion();

@@ -388,17 +464,7 @@ __global__ void __launch_bounds__(128, 1)
    }

    __pipeline_commit();
-     __syncthreads();
-     if (threadIdx.x == 0)
-     {
- #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-         asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
- #elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-         asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
- #else
-         atomicAdd(offset_access_ptr, 1);
- #endif
-     }
+     flags.cta_arrive();
    // Load all inputs
    bool valid = false;
@@ -528,16 +594,7 @@ __global__ void __launch_bounds__(128, 1)
            = out4;
    }
    // Update the buffer pointers
-     if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
-     {
-         // Make sure all blocks have finished accessing the buffer
-         while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-         {
-         }
-         buffer_flags[0] = (flag.x + 1) % 3;
-         buffer_flags[1] = (flag.y + 1) % 3;
-         *(offset_access_ptr) = 0;
-     }
+     flags.wait_and_update(batch_size);
#endif
}

@@ -548,8 +605,6 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons

    // input to rmsnorm is the buffer in the twoshot ar
    // We should use prenorm output to determine the actual used size
-     // int batch = normed_output.sizes()[0];
-     // int dim = normed_output.sizes()[1];
    float _epsilon{static_cast<float>(epsilon)};

    static constexpr int NUM_THREADS = 128;
@@ -612,6 +667,20 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
        default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
        }
    }
+     else if (dtype == nvinfer1::DataType::kHALF)
+     {
+         switch (params.hidden_dim)
+         {
+         case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break;
+         case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break;
+         // Llama-4 Hidden Dimension
+         case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break;
+         // DeepSeek Hidden Dimension
+         case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break;
+         case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break;
+         default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
+         }
+     }
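+     // Note: each case presumably instantiates the RMSNorm kernel with a fixed DIM template
+     // parameter via LAUNCH_RMSNORM_KERNEL, so additional hidden sizes need a new case here.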
    else
    {
        TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype.");