
Commit 375f74e

[fix][nvbugs/5399355] Fix Lamport buffer clear issue for MNNVL TwoShot Allreduce and add FP16 support. (#6237)
Signed-off-by: Shiyu Li <[email protected]>
1 parent f8f5ba6 commit 375f74e

File tree

3 files changed: +249 −137 lines changed

cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu

Lines changed: 155 additions & 86 deletions
@@ -27,6 +27,10 @@
 
 namespace tensorrt_llm::kernels::mnnvl
 {
+
+// Guard for internal helper functions
+namespace
+{
 __device__ bool isNegZero(float v)
 {
     return v == 0.f && signbit(v);
@@ -49,6 +53,12 @@ inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val)
     return __bfloat162float(val);
 }
 
+template <>
+inline __device__ float toFloat<__nv_half>(__nv_half val)
+{
+    return __half2float(val);
+}
+
 template <typename T>
 inline __device__ T fromFloat(float val)
 {
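Note on the new __nv_half support: the kernel keeps accumulating in float and converts only at the boundaries, exactly as the existing __nv_bfloat16 path does, via the toFloat<__nv_half> specialization above and the fromFloat<__nv_half> one in the next hunk. A minimal standalone sketch of that round trip (illustrative only, not part of the diff):

#include <cuda_fp16.h>

// Illustrative device helper: accumulate two half values in float and
// convert back, mirroring toFloat<__nv_half> / fromFloat<__nv_half>.
__device__ __nv_half reduce_pair(__nv_half a, __nv_half b)
{
    float accum = __half2float(a) + __half2float(b); // toFloat<__nv_half>
    return __float2half(accum);                      // fromFloat<__nv_half>
}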
@@ -61,30 +71,76 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
     return __float2bfloat16(val);
 }
 
-__device__ float4 loadfloat4(void const* ptr)
+template <>
+inline __device__ __nv_half fromFloat<__nv_half>(float val)
 {
+    return __float2half(val);
+}
 
-    float return_value[4];
-
-    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
-        : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
-        : "l"(ptr));
-
-    return *(float4*) return_value;
+inline __device__ float2 loadfloat2(void const* ptr)
+{
+    float2 return_value;
+    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(return_value.x), "=f"(return_value.y) : "l"(ptr));
+    return return_value;
 }
 
-__device__ __inline__ float2 loadfloat2(void const* ptr)
+template <typename T>
+inline __device__ T divUp(T val, T divisor)
 {
+    return (val + divisor - 1) / divisor;
+}
 
-    float return_value[2];
+__device__ struct __attribute__((aligned(32))) LamportFlags
+{
+    uint32_t buffer_size;
+    uint32_t input_offset;
+    uint32_t clear_offset;
+    uint32_t num_tokens_prev;
+    uint32_t* offset_access_ptr;
+    uint32_t* buffer_flags;
+
+    __device__ explicit LamportFlags(uint32_t* buffer_flags)
+        : offset_access_ptr(&buffer_flags[4])
+        , buffer_flags(buffer_flags)
+    {
+        uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+        buffer_size = flag.z;
+        input_offset = flag.x * (buffer_size << 1U);
+        clear_offset = flag.y * (buffer_size << 1U);
+        num_tokens_prev = flag.w;
+    }
 
-    asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
-        : "=f"(return_value[0]), "=f"(return_value[1])
-        : "l"(ptr)
-        : "memory");
+    __device__ void cta_arrive()
+    {
+        __syncthreads();
+        if (threadIdx.x == 0)
+        {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
+            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
+#else
+            atomicAdd(offset_access_ptr, 1);
+#endif
+        }
+    }
 
-    return *(float2*) return_value;
-}
+    __device__ void wait_and_update(uint32_t num_tokens)
+    {
+        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
+        {
+            while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
+            {
+            }
+            uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
+            buffer_flags[0] = (flag.x + 1) % 3;
+            buffer_flags[1] = (flag.y + 1) % 3;
+            buffer_flags[3] = num_tokens;
+            *(offset_access_ptr) = 0;
+        }
+    }
+};
+} // namespace
 
 template <int WORLD_SIZE, typename T>
 __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
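Reviewer note on LamportFlags: the constructor reads a five-word flag array whose layout is implied by the code above and by wait_and_update (write-buffer index, clear-buffer index, per-buffer size M * N, token count of the previous launch, and a CTA arrival counter at buffer_flags[4]). Both indices advance modulo 3 per launch, so three groups of 2 * buffer_size elements rotate through write / clear / idle roles. A host-side sketch of that rotation follows; the initial index values and buffer size are assumptions for illustration, not the actual host allocation code:

#include <cstdint>
#include <cstdio>

int main()
{
    // Assumed initial state: write group 0, clear group 2, 8192-element buffers.
    uint32_t write_idx = 0, clear_idx = 2;
    uint32_t const buffer_size = 8192; // flag.z, i.e. M * N per buffer

    for (int launch = 0; launch < 4; ++launch)
    {
        uint32_t input_offset = write_idx * (buffer_size << 1U); // as in the constructor
        uint32_t clear_offset = clear_idx * (buffer_size << 1U);
        std::printf("launch %d: write group %u @ %u, clear group %u @ %u\n",
            launch, write_idx, input_offset, clear_idx, clear_offset);

        // What wait_and_update does at the end of the kernel:
        write_idx = (write_idx + 1) % 3; // buffer_flags[0]
        clear_idx = (clear_idx + 1) % 3; // buffer_flags[1]
    }
    return 0;
}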
@@ -99,13 +155,14 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
     cudaGridDependencySynchronize();
 #endif
 
-    // [input_ptr, clear_ptr, buffer_size, access_counter]
-    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-    // Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
-    uint32_t buffer_group_size = flag.z << 1;
-    uint32_t input_offset = flag.x * buffer_group_size;
-    uint32_t clear_offset = flag.y * buffer_group_size;
-    uint32_t* offset_access_ptr = &buffer_flags[3];
+    LamportFlags flags(buffer_flags);
+
+    // Capture the number of tokens in previous iteration so that we can properly clear the buffer
+    // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up
+    uint32_t clr_toks_cta
+        = divUp<uint32_t>(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, WORLD_SIZE)
+        * WORLD_SIZE;
+    clr_toks_cta = divUp<uint32_t>(clr_toks_cta, gridDim.x);
 
     if (elt < token_dim)
     {
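The clear-count arithmetic above is the heart of the bug fix: each CTA must wipe at least as many token rows as the larger of the previous and current launches, rounded up to WORLD_SIZE because the scatter stage consumes the buffer in WORLD_SIZE-token groups. A small host-side check of the same arithmetic, with assumed sizes:

#include <cstdint>
#include <cstdio>

// Same arithmetic as the device divUp helper, evaluated on the host.
static uint32_t divUp(uint32_t val, uint32_t divisor)
{
    return (val + divisor - 1) / divisor;
}

int main()
{
    uint32_t const num_tokens_prev = 100; // tokens written by the previous launch (assumed)
    uint32_t const num_tokens = 64;       // tokens in this launch (assumed)
    uint32_t const world_size = 8, grid_dim_x = 64;

    uint32_t max_toks = num_tokens_prev > num_tokens ? num_tokens_prev : num_tokens;
    uint32_t clr_toks = divUp(max_toks, world_size) * world_size; // 104
    uint32_t clr_toks_cta = divUp(clr_toks, grid_dim_x);          // 2
    std::printf("%u tokens -> clear %u rows, %u iterations per CTA\n", max_toks, clr_toks, clr_toks_cta);
    return 0;
}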
@@ -115,29 +172,33 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
         T val = shard_ptr[token * token_dim + elt];
         if (isNegZero(val))
             val = fromFloat<T>(0.f);
-        input_ptrs[dest_rank][input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt] = val;
+        input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt]
+            = val;
 
-        // Reduce and broadcast
+        // Clear the buffer used by the previous call. Note the number of tokens to clear could be larger than the
+        // number of tokens in the current call.
+        for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
+        {
+            uint32_t clr_token_idx = token + clr_tok * gridDim.x;
+            if (clr_token_idx < buffer_M)
+            {
+                input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat<T>(-0.f);
+            }
+        }
 
+        // Reduce and broadcast
         if ((token % WORLD_SIZE) == rank)
         {
             int local_token = token / WORLD_SIZE;
             float accum = 0.f;
 
             T values[WORLD_SIZE];
-
-            for (int r = 0; r < WORLD_SIZE; r++)
-            {
-                input_ptrs[rank][clear_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]
-                    = fromFloat<T>(-0.f);
-            }
-
             while (1)
             {
                 bool valid = true;
                 for (int r = 0; r < WORLD_SIZE; r++)
                 {
-                    T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
+                    T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][flags.input_offset
                         + local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
                     values[r] = *lamport_ptr;
                     valid &= !isNegZero(values[r]);
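The polling loop above relies on negative zero as a "not yet written" sentinel: cleared slots hold -0.0f, and real inputs are flushed to +0.0f before being scattered (see the isNegZero check earlier in this hunk), so a slot that no longer compares as negative zero must contain a peer's payload. A minimal sketch of that readiness test, illustrative only:

// Illustrative device helper: a slot is ready once it stops reading back
// as the -0.0f sentinel written by the clearing loops.
__device__ bool slot_ready(float const volatile* slot)
{
    float v = *slot;
    return !(v == 0.f && signbit(v)); // i.e. !isNegZero(v)
}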
@@ -149,40 +210,39 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
             {
                 accum += toFloat<T>(values[r]);
             }
-            mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
+            mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
         }
     }
 
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
     cudaTriggerProgrammaticLaunchCompletion();
 #endif
 
-    input_ptrs[rank][clear_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(-0.f);
+    // Similarly clear broadcast buffer here
+    for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
+    {
+        uint32_t clr_token_idx = token + clr_tok * gridDim.x;
+        if (clr_token_idx < buffer_M)
+        {
+            input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
+                = fromFloat<T>(-0.f);
+        }
+    }
 
     // Optionally wait for results if the next layer isn't doing the Lamport check
     if (wait_for_results)
     {
         // Update the atomic counter to indicate the block has read the offsets
-        __syncthreads();
+        flags.cta_arrive();
 
-        if (threadIdx.x == 0)
-        {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-            asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-            asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#else
-            atomicAdd(offset_access_ptr, 1);
-#endif
-        }
         // Only use a set of CTAs for lamport sync, reargange the grid
         constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
         // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
         if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
         {
             uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
 
-            void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
+            void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
             // We have 2 assumptions here:
             // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
             // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32)
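With FP16 added, the 8-byte Lamport load now covers three element widths; the divisibility comment above ("4 for BF16 and 2 for FP32") applies to FP16 the same way as to BF16. The per-dtype counts, spelled out as compile-time checks (illustrative, not part of the diff):

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// ELTS_PER_LOAD = sizeof(float2) / sizeof(T): the 8-byte volatile load used
// for the Lamport check covers this many elements, so num_tokens * token_dim
// must be divisible by it (and the buffers must stay 8-byte aligned).
static_assert(sizeof(float2) / sizeof(__nv_half) == 4, "FP16: 4 elements per 8B load");
static_assert(sizeof(float2) / sizeof(__nv_bfloat16) == 4, "BF16: 4 elements per 8B load");
static_assert(sizeof(float2) / sizeof(float) == 2, "FP32: 2 elements per 8B load");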
@@ -198,16 +258,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
         }
 
         // Update the buffer flags
-        if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
-        {
-            // Make sure all blocks have finished reading the offsets, 2-D grid
-            while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-            {
-            }
-            buffer_flags[0] = (flag.x + 1) % 3;
-            buffer_flags[1] = (flag.y + 1) % 3;
-            *(offset_access_ptr) = 0;
-        }
+        flags.wait_and_update(num_tokens);
     }
 }
 
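Passing num_tokens into flags.wait_and_update is the other half of the fix: the value lands in buffer_flags[3] and is read back as num_tokens_prev by the next launch, which then clears at least that many rows even if it is itself much smaller. A host-side simulation of the effect (launch sizes are illustrative):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t num_tokens_prev = 0; // buffer_flags[3] before the first launch (assumed)
    uint32_t const launches[] = {512, 32, 32};

    for (uint32_t num_tokens : launches)
    {
        // What the kernel clears, before rounding to WORLD_SIZE granularity.
        uint32_t rows_cleared = std::max(num_tokens_prev, num_tokens);
        std::printf("launch of %u tokens clears >= %u rows\n", num_tokens, rows_cleared);
        num_tokens_prev = num_tokens; // persisted by wait_and_update(num_tokens)
    }
    return 0;
}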
@@ -273,12 +324,28 @@ void twoshot_allreduce_op(AllReduceParams const& params)
         default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size.");
         }
     }
+    else if (dtype == nvinfer1::DataType::kHALF)
+    {
+        switch (world_size)
+        {
+        case 2: LAUNCH_ALL_REDUCE_KERNEL(2, __nv_half); break;
+        case 4: LAUNCH_ALL_REDUCE_KERNEL(4, __nv_half); break;
+        case 8: LAUNCH_ALL_REDUCE_KERNEL(8, __nv_half); break;
+        case 16: LAUNCH_ALL_REDUCE_KERNEL(16, __nv_half); break;
+        case 32: LAUNCH_ALL_REDUCE_KERNEL(32, __nv_half); break;
+        case 64: LAUNCH_ALL_REDUCE_KERNEL(64, __nv_half); break;
+        default: TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported world_size.");
+        }
+    }
     else
     {
         TLLM_CHECK_WITH_INFO(false, "TwoShot AllReduce]: unsupported dtype.");
     }
 }
 
+// Guard for internal helper functions
+namespace
+{
 template <typename T_IN>
 __device__ void copy_f4(T_IN* dst, T_IN const* src)
 {
@@ -327,6 +394,19 @@ inline __device__ float block_reduce_sum(float val)
     return val;
 }
 
+__device__ float4 loadfloat4(void const* ptr)
+{
+
+    float4 return_value;
+
+    asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
+        : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), "=f"(return_value.w)
+        : "l"(ptr));
+
+    return return_value;
+}
+} // namespace
+
 template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>
 __global__ void __launch_bounds__(128, 1)
     RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon,
@@ -353,12 +433,8 @@ __global__ void __launch_bounds__(128, 1)
 
     int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];
 
-    uint32_t* offset_access_ptr = &buffer_flags[3];
-    uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
-    // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
-    uint32_t buffer_size = flag.z;
-    uint32_t buffer_offset = flag.x * (buffer_size << 1);
-    T_IN const* input = &buffer_input[buffer_offset + buffer_size];
+    LamportFlags flags(buffer_flags);
+    T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size];
 
     cudaTriggerProgrammaticLaunchCompletion();
 
@@ -388,17 +464,7 @@ __global__ void __launch_bounds__(128, 1)
     }
 
     __pipeline_commit();
-    __syncthreads();
-    if (threadIdx.x == 0)
-    {
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
-        asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-        asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
-#else
-        atomicAdd(offset_access_ptr, 1);
-#endif
-    }
+    flags.cta_arrive();
     // Load all inputs
     bool valid = false;
 
@@ -528,16 +594,7 @@ __global__ void __launch_bounds__(128, 1)
             = out4;
     }
     // Update the buffer pointers
-    if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
-    {
-        // Make sure all blocks have finished accessing the buffer
-        while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
-        {
-        }
-        buffer_flags[0] = (flag.x + 1) % 3;
-        buffer_flags[1] = (flag.y + 1) % 3;
-        *(offset_access_ptr) = 0;
-    }
+    flags.wait_and_update(batch_size);
 #endif
 }
 
@@ -548,8 +605,6 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons
 
     // input to rmsnorm is the buffer in the twoshot ar
     // We should use prenorm output to determine the actual used size
-    // int batch = normed_output.sizes()[0];
-    // int dim = normed_output.sizes()[1];
     float _epsilon{static_cast<float>(epsilon)};
 
     static constexpr int NUM_THREADS = 128;
@@ -612,6 +667,20 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
         default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
         }
     }
+    else if (dtype == nvinfer1::DataType::kHALF)
+    {
+        switch (params.hidden_dim)
+        {
+        case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break;
+        case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break;
+        // Llama-4 Hidden Dimension
+        case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break;
+        // DeepSeek Hidden Dimension
+        case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break;
+        case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break;
+        default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
+        }
+    }
     else
     {
         TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype.");
0 commit comments

Comments
 (0)