
[fix][nvbugs/5399355] Fix Lamport buffer clear issue for MNNVL TwoShot Allreduce and add FP16 support. #6237

Merged
merged 6 commits on Jul 25, 2025
@@ -27,6 +27,10 @@

namespace tensorrt_llm::kernels::mnnvl
{

// Guard for internal helper functions
namespace
{
__device__ bool isNegZero(float v)
{
return v == 0.f && signbit(v);
@@ -49,6 +53,12 @@ inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val)
return __bfloat162float(val);
}

template <>
inline __device__ float toFloat<__nv_half>(__nv_half val)
{
return __half2float(val);
}

template <typename T>
inline __device__ T fromFloat(float val)
{
@@ -61,30 +71,76 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val)
return __float2bfloat16(val);
}

__device__ float4 loadfloat4(void const* ptr)
template <>
inline __device__ __nv_half fromFloat<__nv_half>(float val)
{
return __float2half(val);
}

float return_value[4];

asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
: "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), "=f"(return_value[3])
: "l"(ptr));

return *(float4*) return_value;
inline __device__ float2 loadfloat2(void const* ptr)
{
float2 return_value;
asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(return_value.x), "=f"(return_value.y) : "l"(ptr));
return return_value;
}

__device__ __inline__ float2 loadfloat2(void const* ptr)
template <typename T>
inline __device__ T divUp(T val, T divisor)
{
return (val + divisor - 1) / divisor;
}

float return_value[2];
__device__ struct __attribute__((aligned(32))) LamportFlags
{
uint32_t buffer_size;
uint32_t input_offset;
uint32_t clear_offset;
uint32_t num_tokens_prev;
uint32_t* offset_access_ptr;
uint32_t* buffer_flags;
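
// buffer_flags word layout, as used by the accesses below:
//   [0] index (0-2) of the buffer group currently used as input
//   [1] index (0-2) of the buffer group to clear
//   [2] size of a single buffer in elements (a group holds two buffers)
//   [3] number of tokens processed by the previous call
//   [4] counter of CTAs that have finished reading the offsets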

__device__ explicit LamportFlags(uint32_t* buffer_flags)
: offset_access_ptr(&buffer_flags[4])
, buffer_flags(buffer_flags)
{
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
buffer_size = flag.z;
input_offset = flag.x * (buffer_size << 1U);
clear_offset = flag.y * (buffer_size << 1U);
num_tokens_prev = flag.w;
}

asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n"
: "=f"(return_value[0]), "=f"(return_value[1])
: "l"(ptr)
: "memory");
__device__ void cta_arrive()
{
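// One thread per CTA bumps the shared arrival counter after all of its threads have synchronized;
// sm_90/sm_100 use a fire-and-forget red(.async) reduction, older architectures fall back to atomicAdd.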
__syncthreads();
if (threadIdx.x == 0)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#else
atomicAdd(offset_access_ptr, 1);
#endif
}
}

return *(float2*) return_value;
}
__device__ void wait_and_update(uint32_t num_tokens)
{
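// The last CTA spins until every CTA has arrived, then advances the input/clear indices through the
// three Lamport buffer groups and records num_tokens so the next call knows how many rows to clear.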
if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
{
while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
{
}
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
buffer_flags[0] = (flag.x + 1) % 3;
buffer_flags[1] = (flag.y + 1) % 3;
buffer_flags[3] = num_tokens;
*(offset_access_ptr) = 0;
}
}
};
} // namespace

template <int WORLD_SIZE, typename T>
__global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens,
@@ -99,13 +155,14 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
cudaGridDependencySynchronize();
#endif

// [input_ptr, clear_ptr, buffer_size, access_counter]
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
// Each buffer is M * N and we have 2 buffers in each group, one for reduce-scatter and one for allgather
uint32_t buffer_group_size = flag.z << 1;
uint32_t input_offset = flag.x * buffer_group_size;
uint32_t clear_offset = flag.y * buffer_group_size;
uint32_t* offset_access_ptr = &buffer_flags[3];
LamportFlags flags(buffer_flags);

// Capture the number of tokens from the previous iteration so that we can properly clear the buffer.
// The scatter stage uses the buffer at WORLD_SIZE granularity, so we need to round up.
uint32_t clr_toks_cta
= divUp<uint32_t>(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, WORLD_SIZE)
* WORLD_SIZE;
clr_toks_cta = divUp<uint32_t>(clr_toks_cta, gridDim.x);
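// Worked example (hypothetical numbers): num_tokens_prev = 10, num_tokens = 6, WORLD_SIZE = 8, gridDim.x = 4
//   -> max(10, 6) rounded up to a multiple of 8 is 16, so each CTA clears divUp(16, 4) = 4 token rows.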

if (elt < token_dim)
{
@@ -115,29 +172,33 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
T val = shard_ptr[token * token_dim + elt];
if (isNegZero(val))
val = fromFloat<T>(0.f);
input_ptrs[dest_rank][input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt] = val;
input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt]
= val;

// Reduce and broadcast
// Clear the buffer used by the previous call. Note the number of tokens to clear could be larger than the
// number of tokens in the current call.
for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
{
uint32_t clr_token_idx = token + clr_tok * gridDim.x;
if (clr_token_idx < buffer_M)
{
input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat<T>(-0.f);
}
}

// Reduce and broadcast
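// Each rank reduces only the tokens it owns (token % WORLD_SIZE == rank) and multicasts the sum to all ranks.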
if ((token % WORLD_SIZE) == rank)
{
int local_token = token / WORLD_SIZE;
float accum = 0.f;

T values[WORLD_SIZE];

for (int r = 0; r < WORLD_SIZE; r++)
{
input_ptrs[rank][clear_offset + local_token * token_dim * WORLD_SIZE + r * token_dim + elt]
= fromFloat<T>(-0.f);
}

while (1)
{
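// Lamport check: a slot still holding -0.0f has not been written by its producer yet,
// so keep polling until contributions from all ranks have landed.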
bool valid = true;
for (int r = 0; r < WORLD_SIZE; r++)
{
T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][input_offset
T volatile* lamport_ptr = (T volatile*) &input_ptrs[rank][flags.input_offset
+ local_token * token_dim * WORLD_SIZE + r * token_dim + elt];
values[r] = *lamport_ptr;
valid &= !isNegZero(values[r]);
@@ -149,40 +210,39 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
{
accum += toFloat<T>(values[r]);
}
mcast_ptr[input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(accum);
}
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif

input_ptrs[rank][clear_offset + buffer_M * token_dim + token * token_dim + elt] = fromFloat<T>(-0.f);
// Similarly clear broadcast buffer here
for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++)
{
uint32_t clr_token_idx = token + clr_tok * gridDim.x;
if (clr_token_idx < buffer_M)
{
input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt]
= fromFloat<T>(-0.f);
}
}

// Optionally wait for results if the next layer isn't doing the Lamport check
if (wait_for_results)
{
// Update the atomic counter to indicate the block has read the offsets
__syncthreads();
flags.cta_arrive();

if (threadIdx.x == 0)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#else
atomicAdd(offset_access_ptr, 1);
#endif
}
// Only use a subset of threads for the Lamport sync, rearrange the grid
constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
// blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
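// sizeof(float2) is 8B, so ELTS_PER_LOAD is 4 for 16-bit types (BF16/FP16) and 2 for FP32;
// e.g. with a hypothetical blockDim.x of 128, 32 (BF16/FP16) or 64 (FP32) threads per CTA do the polling.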
if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD))
{
uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;

void* lamport_ptr = (void*) &input_ptrs[rank][input_offset + buffer_M * token_dim + current_pos];
void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
// We have 2 assumptions here:
// 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B
// 2. The num_tokens * token_dim is divisible by ELTS_PER_LOAD (4 for BF16/FP16 and 2 for FP32)
@@ -198,16 +258,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
}

// Update the buffer flags
if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0)
{
// Make sure all blocks have finished reading the offsets, 2-D grid
while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
{
}
buffer_flags[0] = (flag.x + 1) % 3;
buffer_flags[1] = (flag.y + 1) % 3;
*(offset_access_ptr) = 0;
}
flags.wait_and_update(num_tokens);
}
}

@@ -273,12 +324,28 @@ void twoshot_allreduce_op(AllReduceParams const& params)
default: TLLM_CHECK_WITH_INFO(false, "[TwoShot AllReduce]: unsupported world_size.");
}
}
else if (dtype == nvinfer1::DataType::kHALF)
{
switch (world_size)
{
case 2: LAUNCH_ALL_REDUCE_KERNEL(2, __nv_half); break;
case 4: LAUNCH_ALL_REDUCE_KERNEL(4, __nv_half); break;
case 8: LAUNCH_ALL_REDUCE_KERNEL(8, __nv_half); break;
case 16: LAUNCH_ALL_REDUCE_KERNEL(16, __nv_half); break;
case 32: LAUNCH_ALL_REDUCE_KERNEL(32, __nv_half); break;
case 64: LAUNCH_ALL_REDUCE_KERNEL(64, __nv_half); break;
default: TLLM_CHECK_WITH_INFO(false, "[TwoShot AllReduce]: unsupported world_size.");
}
}
else
{
TLLM_CHECK_WITH_INFO(false, "[TwoShot AllReduce]: unsupported dtype.");
}
}

// Guard for internal helper functions
namespace
{
template <typename T_IN>
__device__ void copy_f4(T_IN* dst, T_IN const* src)
{
@@ -327,6 +394,19 @@ inline __device__ float block_reduce_sum(float val)
return val;
}

__device__ float4 loadfloat4(void const* ptr)
{
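// Volatile 16B vector load; keeps the compiler from caching the value so repeated reads observe fresh data.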

float4 return_value;

asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n"
: "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), "=f"(return_value.w)
: "l"(ptr));

return return_value;
}
} // namespace

template <int DIM, int NUM_THREADS, int NUM_INPUTS, typename T_OUT, typename T_IN>
__global__ void __launch_bounds__(128, 1)
RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon,
@@ -353,12 +433,8 @@ __global__ void __launch_bounds__(128, 1)

int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)];

uint32_t* offset_access_ptr = &buffer_flags[3];
uint4 flag = reinterpret_cast<uint4*>(buffer_flags)[0];
// Buffer size is M * N, and we need two buffers for reduce-scatter and allgather
uint32_t buffer_size = flag.z;
uint32_t buffer_offset = flag.x * (buffer_size << 1);
T_IN const* input = &buffer_input[buffer_offset + buffer_size];
LamportFlags flags(buffer_flags);
T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size];
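// input_offset selects the active buffer group; adding buffer_size skips the reduce-scatter half
// of the group so `input` points at the allgather (allreduce output) half.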

cudaTriggerProgrammaticLaunchCompletion();

@@ -388,17 +464,7 @@ __global__ void __launch_bounds__(128, 1)
}

__pipeline_commit();
__syncthreads();
if (threadIdx.x == 0)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000))
asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory");
#else
atomicAdd(offset_access_ptr, 1);
#endif
}
flags.cta_arrive();
// Load all inputs
bool valid = false;

@@ -528,16 +594,7 @@ __global__ void __launch_bounds__(128, 1)
= out4;
}
// Update the buffer pointers
if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
{
// Make sure all blocks have finished accessing the buffer
while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y)
{
}
buffer_flags[0] = (flag.x + 1) % 3;
buffer_flags[1] = (flag.y + 1) % 3;
*(offset_access_ptr) = 0;
}
flags.wait_and_update(batch_size);
#endif
}

Expand All @@ -548,8 +605,6 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons

// input to rmsnorm is the buffer in the twoshot ar
// We should use prenorm output to determine the actual used size
// int batch = normed_output.sizes()[0];
// int dim = normed_output.sizes()[1];
float _epsilon{static_cast<float>(epsilon)};

static constexpr int NUM_THREADS = 128;
@@ -612,6 +667,20 @@ void twoshot_rmsnorm_op(RMSNormParams const& params)
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
}
}
else if (dtype == nvinfer1::DataType::kHALF)
{
switch (params.hidden_dim)
{
case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break;
case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break;
// Llama-4 Hidden Dimension
case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break;
// DeepSeek Hidden Dimension
case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break;
case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break;
default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim.");
}
}
else
{
TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype.");