From 3e9c1b375f9eed6424f5ff4c1edc49c9fd7707a8 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 24 Jul 2025 12:56:19 -0700 Subject: [PATCH 1/8] Add optimizations to MNNVL kernel. --- flashinfer/comm/mnnvl.py | 77 ++++-- flashinfer/comm/trtllm_mnnvl_ar.py | 4 +- .../comm/trtllm_mnnvl_allreduce.cuh | 221 +++++++++------ tests/test_trtllm_mnnvl_allreduce.py | 253 ++++++++++-------- 4 files changed, 338 insertions(+), 217 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index df49c545c1..3df8010498 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -18,7 +18,7 @@ import platform import sys from dataclasses import dataclass -from typing import List +from typing import List, Optional import pynvml import torch @@ -110,6 +110,26 @@ def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool: return False +def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: + """ + A helper function that allocates memory on cuda and copies the data from the host to the device. + """ + if not host_ptr_array: + return None + + ArrayType = ctypes.c_uint64 * len(host_ptr_array) + c_array = ArrayType(*host_ptr_array) + size_in_bytes = ctypes.sizeof(c_array) + + device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) + checkCudaErrors( + cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes) + ) + # c_array should be freed by GC + + return device_ptr + + class MpiComm: _comm: MPI.Intracomm = MPI.COMM_WORLD @@ -423,7 +443,9 @@ def __init__( # CUDA memory handles and pointers self.mc_ptr = 0 # CUdeviceptr mMcPtr self.uc_ptrs: List[int] = [] # std::vector mUcPtrs - self.signal_pads_dev: List[int] = [] # std::vector mSignalPadsDev + self.signal_pads: List[int] = [] # mSignalPads + self.signal_pads_dev = 0 # std::vector mSignalPadsDev + self.uc_ptrs_dev = 0 self.mc_handle = 0 # CUmemGenericAllocationHandle mMcHandle self.uc_handles: List[int] = ( [] @@ -475,14 +497,18 @@ def __init__( raise NotImplementedError("Single-node NVLS allocation not implemented yet") # Initialize signal pads - self.signal_pads_dev = [0] * self.group_size + self.signal_pads = [0] * self.group_size for i in range(self.group_size): - self.signal_pads_dev[i] = self.uc_ptrs[i] + self.signal_pad_offset + self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset if i == self.group_rank: checkCudaErrors( - cuda.cuMemsetD8(self.signal_pads_dev[i], 0, self.SIGNAL_PAD_SIZE) + cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) ) + # Create device pointers + self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) + self.uc_ptrs_dev = alloc_and_copy_to_cuda(self.uc_ptrs) + def __del__(self): """Destructor - cleanup allocated memory""" @@ -505,6 +531,12 @@ def __del__(self): print(f"Destructor: CUDA context invalid, skipping cleanup: {e}") return + # Free device pointers + if self.signal_pads_dev: + checkCudaErrors(cuda.cuMemFree(self.signal_pads_dev)) + if self.uc_ptrs_dev: + checkCudaErrors(cuda.cuMemFree(self.uc_ptrs_dev)) + # Unmap UC regions and release their handles if hasattr(self, "uc_handles") and self.uc_handles: for rank in range(self.group_size): @@ -541,14 +573,22 @@ def __del__(self): except Exception as e: print(f"Destructor: Failed to release MC handle: {e}") - def get_signal_pad_ptrs_dev(self) -> List[int]: + def get_signal_pad_ptrs_host(self) -> List[int]: """Get the raw array of signal pad pointers to all ranks (including self)""" - return self.signal_pads_dev + return self.signal_pads - def 
get_buffer_ptrs_dev(self) -> List[int]: + def get_buffer_ptrs_host(self) -> List[int]: """Get the raw array of unicast pointers to all ranks (including self)""" return self.uc_ptrs + def get_signal_pad_ptrs_dev(self) -> int: + """Get the raw array of signal pad pointers to all ranks (including self)""" + return self.signal_pads_dev + + def get_buffer_ptrs_dev(self) -> int: + """Get the raw array of unicast pointers to all ranks (including self)""" + return self.uc_ptrs_dev + def get_unicast_ptr(self, rank: int) -> int: """Get the raw unicast pointer to a given rank""" if rank >= len(self.uc_ptrs): @@ -842,29 +882,12 @@ def get_multicast_ptr_as_int64(self) -> int: """Get the multicast pointer as int64""" return self.get_multicast_ptr() - def get_buffer_ptrs_dev(self) -> List[int]: + def get_buffer_ptrs_dev(self) -> int: """Get the buffer pointers device array""" return self.mcast_device_memory.get_buffer_ptrs_dev() def get_buffer_ptrs_dev_as_int64(self) -> int: """Get the buffer pointers device as int64 (returning first UC pointer)""" - ptrs = self.get_buffer_ptrs_dev() + ptrs = self.mcast_device_memory.get_buffer_ptrs_host() assert ptrs is not None return ptrs[0] if ptrs else 0 - - def get_buffer_ptrs_dev_as_ctypes_ptr(self) -> int: - """ - Get buffer pointers as ctypes array pointer (equivalent to C++ void**). - Returns the address of a ctypes array that can be cast to int64_t and back to void**. - - This matches the C++ pattern: - reinterpret_cast(reinterpret_cast(mUcPtrs.data())) - """ - # Create ctypes array of void pointers - ArrayType = ctypes.c_void_p * len(self.mcast_device_memory.uc_ptrs) - self._buffer_ptrs_array = ArrayType( - *self.mcast_device_memory.uc_ptrs - ) # Keep reference to prevent GC - - # Return the address of this array (equivalent to .data() in C++) - return ctypes.cast(self._buffer_ptrs_array, ctypes.c_void_p).value diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 08fc983be2..d6dd0506e8 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -184,9 +184,9 @@ def get_allreduce_mnnvl_workspace( mpi_barrier() # This is a buffer to maintain the state of this allreduce Op - # [Buffer_ptr, Clear_ptr, Buffer_size, atomic access counter] + # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] buffer_flags = torch.tensor( - [0, 2, max_num_elements, 0], + [0, 2, max_num_elements, 0, 0], dtype=torch.uint32, device=torch.device("cuda", mapping.local_rank), ) diff --git a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh index 5a9b88613d..3dbed4b649 100644 --- a/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh +++ b/include/flashinfer/comm/trtllm_mnnvl_allreduce.cuh @@ -63,6 +63,8 @@ __device__ bool isNegZero(float v) { return v == 0.f && signbit(v); } __device__ bool isNegZero(__nv_bfloat16 val) { return isNegZero(__bfloat162float(val)); } +__device__ bool isNegZero(__nv_half val) { return isNegZero(__half2float(val)); } + template inline __device__ float toFloat(T val) { return val; @@ -73,6 +75,11 @@ inline __device__ float toFloat<__nv_bfloat16>(__nv_bfloat16 val) { return __bfloat162float(val); } +template <> +inline __device__ float toFloat<__nv_half>(__nv_half val) { + return __half2float(val); +} + template inline __device__ T fromFloat(float val) { return val; @@ -83,6 +90,68 @@ inline __device__ __nv_bfloat16 fromFloat<__nv_bfloat16>(float val) { return __float2bfloat16(val); } +template <> +inline __device__ 
__nv_half fromFloat<__nv_half>(float val) { + return __float2half(val); +} + +inline __device__ float2 loadfloat2(void const* ptr) { + float2 return_value; + asm volatile("ld.volatile.global.v2.f32 {%0, %1}, [%2];\n" + : "=f"(return_value.x), "=f"(return_value.y) + : "l"(ptr)); + return return_value; +} + +template +inline __device__ T divUp(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +__device__ struct __attribute__((aligned(32))) LamportFlags { + uint32_t buffer_size; + uint32_t input_offset; + uint32_t clear_offset; + uint32_t num_tokens_prev; + uint32_t* offset_access_ptr; + uint32_t* buffer_flags; + + __device__ explicit LamportFlags(uint32_t* buffer_flags) + : offset_access_ptr(&buffer_flags[4]), buffer_flags(buffer_flags) { + uint4 flag = reinterpret_cast(buffer_flags)[0]; + buffer_size = flag.z; + input_offset = flag.x * (buffer_size << 1U); + clear_offset = flag.y * (buffer_size << 1U); + num_tokens_prev = flag.w; + } + + __device__ void cta_arrive() { + __syncthreads(); + if (threadIdx.x == 0) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)) + asm volatile("red.async.release.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) + : "memory"); +#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("red.global.gpu.add.u32 [%0], %1;" ::"l"(offset_access_ptr), "r"(1) : "memory"); +#else + atomicAdd(offset_access_ptr, 1); +#endif + } + } + + __device__ void wait_and_update(uint32_t num_tokens) { + if (threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == 0) { + while (*reinterpret_cast(offset_access_ptr) < gridDim.x * gridDim.y) { + } + uint4 flag = reinterpret_cast(buffer_flags)[0]; + buffer_flags[0] = (flag.x + 1) % 3; + buffer_flags[1] = (flag.y + 1) % 3; + buffer_flags[3] = num_tokens; + *(offset_access_ptr) = 0; + } + } +}; + template __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens, int buffer_M, int token_dim, int rank, @@ -96,18 +165,15 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ cudaGridDependencySynchronize(); #endif - uint32_t* offset_access_ptr = &buffer_flags[3]; - // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather - uint32_t buffer_size = (buffer_flags[2] << 1); - uint32_t input_offset = buffer_flags[0] * buffer_size; - uint32_t clear_offset = buffer_flags[1] * buffer_size; + LamportFlags flags(buffer_flags); - if (wait_for_results) { - __syncthreads(); - if (threadIdx.x == 0) { - atomicAdd(offset_access_ptr, 1); - } - } + // Capture the number of tokens in previous iteration so that we can properly clear the buffer + // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up + uint32_t clr_toks_cta = + divUp(flags.num_tokens_prev > num_tokens ? flags.num_tokens_prev : num_tokens, + WORLD_SIZE) * + WORLD_SIZE; + clr_toks_cta = divUp(clr_toks_cta, gridDim.x); if (elt < token_dim) { // Scatter token @@ -115,28 +181,33 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ int dest_token_offset = token / WORLD_SIZE; T val = shard_ptr[token * token_dim + elt]; if (isNegZero(val)) val = fromFloat(0.f); - input_ptrs[dest_rank][input_offset + dest_token_offset * token_dim * WORLD_SIZE + + input_ptrs[dest_rank][flags.input_offset + dest_token_offset * token_dim * WORLD_SIZE + rank * token_dim + elt] = val; - // Reduce and broadcast + // Clear the buffer used by the previous call. 
+    // Note the number of tokens to clear could be larger than the number of tokens in the
+    // current call.
+    for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) {
+      uint32_t clr_token_idx = token + clr_tok * gridDim.x;
+      if (clr_token_idx < buffer_M) {
+        input_ptrs[rank][flags.clear_offset + clr_token_idx * token_dim + elt] = fromFloat<T>(-0.f);
+      }
+    }
 
-    int global_token = token * WORLD_SIZE + rank;
-    if (global_token < num_tokens) {
+    // Reduce and broadcast
+    if ((token % WORLD_SIZE) == rank) {
+      int local_token = token / WORLD_SIZE;
       float accum = 0.f;
 
       T values[WORLD_SIZE];
 
-      for (int r = 0; r < WORLD_SIZE; r++) {
-        input_ptrs[rank][clear_offset + token * token_dim * WORLD_SIZE + r * token_dim + elt] =
-            fromFloat<T>(-0.f);
-      }
-
       while (1) {
         bool valid = true;
         for (int r = 0; r < WORLD_SIZE; r++) {
           T volatile* lamport_ptr =
-              (T volatile*)&input_ptrs[rank][input_offset + token * token_dim * WORLD_SIZE +
-                                             r * token_dim + elt];
+              (T volatile*)&input_ptrs[rank]
+                                      [flags.input_offset + local_token * token_dim * WORLD_SIZE +
+                                       r * token_dim + elt];
           values[r] = *lamport_ptr;
           valid &= !isNegZero(values[r]);
         }
@@ -145,7 +216,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
         for (int r = 0; r < WORLD_SIZE; r++) {
           accum += toFloat<T>(values[r]);
         }
-        mcast_ptr[input_offset + buffer_M * token_dim + global_token * token_dim + elt] =
+        mcast_ptr[flags.input_offset + buffer_M * token_dim + token * token_dim + elt] =
             fromFloat<T>(accum);
       }
     }
@@ -154,26 +225,43 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_
   cudaTriggerProgrammaticLaunchCompletion();
 #endif
 
-  input_ptrs[rank][clear_offset + buffer_M * token_dim + token * token_dim + elt] =
-      fromFloat<T>(-0.f);
+  // Similarly, clear the broadcast buffer here
+  for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) {
+    uint32_t clr_token_idx = token + clr_tok * gridDim.x;
+    if (clr_token_idx < buffer_M) {
+      input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim +
+                       elt] = fromFloat<T>(-0.f);
+    }
+  }
 
   // Optionally wait for results if the next layer isn't doing the Lamport check
   if (wait_for_results) {
-    T volatile* lamport_ptr = (T volatile*)&input_ptrs[rank][input_offset + buffer_M * token_dim +
-                                                             token * token_dim + elt];
-    T val = *lamport_ptr;
-    while (isNegZero(val)) val = *lamport_ptr;
-
-    // Copy if requested
-    if (output_ptr) output_ptr[token * token_dim + elt] = val;
-    if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) {
-      // Make sure all blocks have finished reading the offsets, 2-D grid
-      while (*reinterpret_cast<uint32_t volatile*>(offset_access_ptr) < gridDim.x * gridDim.y) {
+    // Update the atomic counter to indicate the block has read the offsets
+    flags.cta_arrive();
+    // Only use a set of CTAs for the Lamport sync; rearrange the grid
+    constexpr int ELTS_PER_LOAD = sizeof(float2) / sizeof(T);
+    // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32)
+    if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD)) {
+      uint64_t current_pos =
+          blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD;
+
+      void* lamport_ptr =
+          (void*)&input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos];
+      // We have 2 assumptions here:
+      // 1. The write is atomic at 8B granularity -> each buffer in the buffer group should be
+      //    aligned to 8B
+      // 2.
The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) + float2 val = loadfloat2(lamport_ptr); + while (isNegZero(*(T*)&val)) { + val = loadfloat2(lamport_ptr); + } + if (output_ptr) { + *((float2*)&output_ptr[current_pos]) = val; } - buffer_flags[0] = (buffer_flags[0] + 1) % 3; - buffer_flags[1] = (buffer_flags[1] + 1) % 3; - *(offset_access_ptr) = 0; } + + // Update the buffer flags + flags.wait_and_update(num_tokens); } } @@ -247,23 +335,23 @@ __device__ float4 loadfloat4(void const* ptr) { // Check alignment - ptr should be 16-byte aligned for safe float4 load if (reinterpret_cast(ptr) % 16 != 0) { // Fall back to scalar loads if not aligned - float return_value[4]; + float4 return_value; float const* float_ptr = reinterpret_cast(ptr); - return_value[0] = float_ptr[0]; - return_value[1] = float_ptr[1]; - return_value[2] = float_ptr[2]; - return_value[3] = float_ptr[3]; - return *(float4*)return_value; + return_value.x = float_ptr[0]; + return_value.y = float_ptr[1]; + return_value.z = float_ptr[2]; + return_value.w = float_ptr[3]; + return return_value; } - float return_value[4]; + float4 return_value; asm volatile("ld.volatile.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" - : "=f"(return_value[0]), "=f"(return_value[1]), "=f"(return_value[2]), - "=f"(return_value[3]) + : "=f"(return_value.x), "=f"(return_value.y), "=f"(return_value.z), + "=f"(return_value.w) : "l"(ptr)); - return *(float4*)return_value; + return return_value; } // Safer version that checks bounds before loading @@ -351,21 +439,13 @@ __global__ void __launch_bounds__(128, 1) int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)]; - uint32_t* offset_access_ptr = &buffer_flags[3]; - // Buffer size is M * N, and we need two buffers for reduce-scatter and allgather - uint32_t buffer_size = buffer_flags[2]; - uint32_t buffer_offset = buffer_flags[0] * (buffer_size << 1); - T_IN const* input = &buffer_input[buffer_offset + buffer_size]; + LamportFlags flags(buffer_flags); + T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size]; #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif - __syncthreads(); - if (threadIdx.x == 0) { - atomicAdd(offset_access_ptr, 1); - } - for (int i = 0; i < NUM_INPUTS; i++) { for (int j = 0; j < DIM / (1 * ELTS_PER_THREAD * NUM_THREADS); j++) { int k = j * NUM_THREADS + threadIdx.x; @@ -390,6 +470,7 @@ __global__ void __launch_bounds__(128, 1) } __pipeline_commit(); + flags.cta_arrive(); // Load all inputs bool valid = false; @@ -414,25 +495,19 @@ __global__ void __launch_bounds__(128, 1) // So the actual pointer we're accessing is: input + element_offset // Which equals: &buffer_input[buffer_offset + buffer_size + element_offset] - // Calculate the total buffer size in elements - int total_buffer_elements = (buffer_size << 1) / sizeof(T_IN); // Two buffers worth - - // The maximum valid element index relative to the input pointer - int max_valid_element_index = total_buffer_elements - buffer_size / sizeof(T_IN); - float4* src4 = (float4*)&input[element_offset]; float4 value; // Check if we have enough elements remaining for a safe float4 load - if (element_offset >= 0 && element_offset + ELTS_PER_THREAD <= max_valid_element_index) { + if (element_offset >= 0 && element_offset + ELTS_PER_THREAD <= flags.buffer_size) { value = loadfloat4(src4); } else { // Use safe load for boundary cases or out-of-bounds - int remaining_elements = max_valid_element_index - 
element_offset; + int remaining_elements = flags.buffer_size - element_offset; if (remaining_elements <= 0) { // Completely out of bounds, return zeros - float return_value[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - value = *(float4*)return_value; + float4 return_value = {0.0f, 0.0f, 0.0f, 0.0f}; + value = return_value; } else { value = loadfloat4_safe(reinterpret_cast(src4), remaining_elements); } @@ -537,15 +612,7 @@ __global__ void __launch_bounds__(128, 1) threadIdx.x * ELTS_PER_THREAD] = out4; } // Update the buffer pointers - if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) { - // Make sure all blocks have finished accessing the buffer - while (*reinterpret_cast(offset_access_ptr) != gridDim.x * gridDim.y) { - } - buffer_flags[0] = (buffer_flags[0] + 1) % 3; - buffer_flags[1] = (buffer_flags[1] + 1) % 3; - *(offset_access_ptr) = 0; - } - __syncthreads(); + flags.wait_and_update(batch_size); #endif } diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py index 3117321c78..096aea1c28 100644 --- a/tests/test_trtllm_mnnvl_allreduce.py +++ b/tests/test_trtllm_mnnvl_allreduce.py @@ -24,10 +24,14 @@ def row_linear_residual_norm_fusion_forward( eps: float, hidden_size: int, dtype: torch.dtype, - tensor_parallel_size: int, - tensor_parallel_rank: int, + mapping: Mapping, fusion: bool, reference_output: tuple[torch.Tensor, ...], + multicast_ptr: int, + buffer_ptrs_dev: torch.Tensor, + unicast_ptr: int, + max_num_elements_mnnvl: int, + buffer_flags_mnnvl: torch.Tensor, ): x = x.cuda() @@ -35,13 +39,10 @@ def row_linear_residual_norm_fusion_forward( norm_weight = norm_weight.cuda() reference_output = tuple(t.cuda() for t in reference_output) - MPI.COMM_WORLD.barrier() + tensor_parallel_size = mapping.tp_size + tensor_parallel_rank = mapping.tp_rank - mapping = Mapping( - world_size=tensor_parallel_size, - tp_size=tensor_parallel_size, - rank=tensor_parallel_rank, - ) + MPI.COMM_WORLD.barrier() def func( input, @@ -57,8 +58,6 @@ def func( # For both fused and unfused cases: shape = input.shape - hidden_size = shape[-1] - assert max_num_elements_mnnvl % hidden_size == 0 input = input.view(-1, shape[-1]) @@ -109,78 +108,75 @@ def func( ) return (output.view(shape),) - # Get workspace buffers using MPI rank - mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( - trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype) + output = func( + x.clone(), + residual.clone(), + norm_weight, + eps, + fusion, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, ) - multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr_as_int64() - buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev_as_ctypes_ptr() - unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( - tensor_parallel_rank - ) + assert output[0].shape == reference_output[0].shape - try: - output = func( - x.clone(), - residual.clone(), - norm_weight, - eps, - fusion, - multicast_ptr, - buffer_ptrs_dev, - unicast_ptr, - max_num_elements_mnnvl, + if tensor_parallel_rank == 0: + print("output[0] (first 10 values):", output[0].flatten()[:10]) + print( + "reference_output[0] (first 10 values):", + reference_output[0].flatten()[:10], ) - assert output[0].shape == reference_output[0].shape - - if tensor_parallel_rank == 0: - print("output[0] (first 10 values):", output[0].flatten()[:10]) + if fusion: + print("output[1] (first 10 values):", output[1].flatten()[:10]) print( - "reference_output[0] (first 10 values):", - reference_output[0].flatten()[:10], + 
"reference_output[1] (first 10 values):", + reference_output[1].flatten()[:10], ) - if fusion: - print("output[1] (first 10 values):", output[1].flatten()[:10]) - print( - "reference_output[1] (first 10 values):", - reference_output[1].flatten()[:10], - ) + torch.testing.assert_close( + output[0], + reference_output[0], + rtol=0.05, + atol=0.15, + ) + if fusion: torch.testing.assert_close( - output[0], - reference_output[0], + output[1], + reference_output[1], rtol=0.05, atol=0.15, ) - if fusion: - torch.testing.assert_close( - output[1], - reference_output[1], - rtol=0.05, - atol=0.15, - ) - - finally: - # Ensure cleanup happens even if assertions fail - del mcast_buffer_mnnvl - """Main test function that runs on each MPI rank""" # seq_lens = [1, 4, 32, 128] -@pytest.mark.parametrize("seq_len", [4]) +@pytest.mark.parametrize( + "seq_lens", + [ + [1], + [4], + [15], + [27, 11, 24], + [127], + ], +) # Test with different sequence length lists @pytest.mark.parametrize("fusion", [False, True]) -def test_mnnvl_allreduce_full(monkeypatch, seq_len: int, fusion: bool): +def test_mnnvl_allreduce_full(monkeypatch, seq_lens: list[int], fusion: bool): monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. # Get MPI info rank = MPI.COMM_WORLD.Get_rank() world_size = MPI.COMM_WORLD.Get_size() + gpus_per_node = torch.cuda.device_count() + + if gpus_per_node == 0: + pytest.skip("MNNVL allreduce test requires at least one CUDA device per node") # Ensure we have exactly 2 ranks for this test if world_size < 2: @@ -188,12 +184,23 @@ def test_mnnvl_allreduce_full(monkeypatch, seq_len: int, fusion: bool): print(f"ERROR: This test requires at least 2 MPI ranks, got {world_size}") sys.exit(1) + mapping = Mapping( + world_size=world_size, + rank=rank, + gpus_per_node=gpus_per_node, + tp_size=world_size, + ) + # Set CUDA device based on rank - torch.cuda.set_device(rank) + torch.cuda.set_device(mapping.local_rank) - if rank == 0: - print(f"Running MNNVL AllReduce test with {world_size} ranks") - print(f"Rank {rank} using GPU {torch.cuda.current_device()}") + if mapping.local_rank == 0: + print( + f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks" + ) + print( + f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" + ) hidden_size = 7168 dtype = torch.bfloat16 @@ -207,68 +214,87 @@ def test_mnnvl_allreduce_full(monkeypatch, seq_len: int, fusion: bool): failure_message = "" try: - if rank == 0: - print( - f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}" - ) - - # Generate test data (same on all ranks due to same seed) - x_full = torch.randn( - (tensor_parallel_size, seq_len, hidden_size), - dtype=dtype, - device=torch.device("cuda"), - ) - residual = torch.randn( - (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") - ) - norm_weight = torch.randn( - (hidden_size,), dtype=dtype, device=torch.device("cuda") + # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list + # This workspace is sized for the maximum expected sequence length and can be reused within each list + # Each parameterized list gets its own fresh workspace allocation + mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = ( + trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype) ) - # Each rank gets its slice of the input - x = x_full[rank, :, :] + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr_as_int64() + buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() 
+ unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( + mapping.tp_rank + ) - # Compute reference output based on fusion mode - if fusion: - # Fused case: AllReduce + Residual Add + RMS Norm - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - residual_out = allreduce_result + residual # Add residual - print( - "Device of residual_out:{}, norm_weight:{}".format( - residual_out.device, norm_weight.device + # Test each sequence length with the same workspace (reusing allocated buffers within this list) + for seq_len in seq_lens: + if rank == 0: + print( + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}" ) + + # Generate test data (same on all ranks due to same seed) + x_full = torch.randn( + (tensor_parallel_size, seq_len, hidden_size), + dtype=dtype, + device=torch.device("cuda"), + ) + residual = torch.randn( + (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda") + ) + norm_weight = torch.randn( + (hidden_size,), dtype=dtype, device=torch.device("cuda") ) - norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) - reference_output = (norm_out, residual_out) - else: - # Non-fused case: Only AllReduce - allreduce_result = torch.sum(x_full, dim=0) # AllReduce result - reference_output = (allreduce_result,) - - # Run the test - row_linear_residual_norm_fusion_forward( - x, - residual, - norm_weight, - eps, - hidden_size, - dtype, - tensor_parallel_size, - rank, - fusion, - reference_output, - ) + # Each rank gets its slice of the input + x = x_full[rank, :, :] - # Synchronize before next test - trtllm_mnnvl_ar.mpi_barrier() + # Compute reference output based on fusion mode + if fusion: + # Fused case: AllReduce + Residual Add + RMS Norm + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + residual_out = allreduce_result + residual # Add residual + print( + "Device of residual_out:{}, norm_weight:{}".format( + residual_out.device, norm_weight.device + ) + ) + norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) - print(f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}") + reference_output = (norm_out, residual_out) + else: + # Non-fused case: Only AllReduce + allreduce_result = torch.sum(x_full, dim=0) # AllReduce result + reference_output = (allreduce_result,) + + # Run the test with the same workspace + row_linear_residual_norm_fusion_forward( + x, + residual, + norm_weight, + eps, + hidden_size, + dtype, + mapping, + fusion, + reference_output, + multicast_ptr, + buffer_ptrs_dev, + unicast_ptr, + max_num_elements_mnnvl, + buffer_flags_mnnvl, + ) + + # Synchronize before next test + trtllm_mnnvl_ar.mpi_barrier() + + print(f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}") except Exception as e: rank_failed = True failure_message = ( - f"FAILED[rank={rank}]: seq_len={seq_len}, fusion={fusion} failed: {e}" + f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion} failed: {e}" ) print(failure_message) # Gather failure status from all ranks @@ -284,5 +310,10 @@ def test_mnnvl_allreduce_full(monkeypatch, seq_len: int, fusion: bool): pytest.fail(f"Test failed on ranks {failed_ranks}") trtllm_mnnvl_ar.mpi_barrier() + finally: + # Ensure cleanup happens for this list's workspace + if "mcast_buffer_mnnvl" in locals(): + del mcast_buffer_mnnvl + # Final synchronization and check for failures across all ranks trtllm_mnnvl_ar.mpi_barrier() From fe588f74df30e7eedba86f726655efbebfb69458 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 24 Jul 2025 12:58:21 -0700 Subject: 
[PATCH 2/8] Fix wrong type hint. --- tests/test_trtllm_mnnvl_allreduce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py index 096aea1c28..b219ee52a7 100644 --- a/tests/test_trtllm_mnnvl_allreduce.py +++ b/tests/test_trtllm_mnnvl_allreduce.py @@ -28,7 +28,7 @@ def row_linear_residual_norm_fusion_forward( fusion: bool, reference_output: tuple[torch.Tensor, ...], multicast_ptr: int, - buffer_ptrs_dev: torch.Tensor, + buffer_ptrs_dev: int, unicast_ptr: int, max_num_elements_mnnvl: int, buffer_flags_mnnvl: torch.Tensor, From 7cae81e25f89434c62836a170f0b75e0e11e5dba Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 24 Jul 2025 13:51:03 -0700 Subject: [PATCH 3/8] Address code review issue for alignment. --- csrc/trtllm_mnnvl_allreduce.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index f84f336b7c..bb97b5de8f 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -37,6 +37,8 @@ void trtllm_mnnvl_all_reduce(at::Tensor& in, int64_t multicast_buffer_ptr, int64 int64_t token_dim = in.size(1); // Validate input parameters + TORCH_CHECK(token_dim % (sizeof(float2) / sizeof(c_type)) == 0, + "token_dim must be divisible by ", sizeof(float2) / sizeof(c_type)); TORCH_CHECK(nranks >= 2 && nranks <= 64, "nranks must be between 2 and 64, got ", nranks); TORCH_CHECK(rank >= 0 && rank < nranks, "rank must be between 0 and nranks-1, got ", rank); TORCH_CHECK(out.has_value() || !wait_for_results, From 19e4ca41f48df4e1586d857ed9163223b96b9f30 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 24 Jul 2025 13:54:42 -0700 Subject: [PATCH 4/8] Remove unused functions. --- flashinfer/comm/mnnvl.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 3df8010498..75a060662e 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -795,16 +795,8 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) ) - def get_multicast_ptr_as_int64(self) -> int: - """Get multicast pointer as int64 (legacy compatibility)""" - return self.get_multicast_ptr() - - def get_buffer_ptrs_dev_as_int64(self) -> int: - """Get buffer pointers device as int64 (returning first UC pointer for now) (legacy compatibility)""" - return self.uc_ptrs[0] if self.uc_ptrs else 0 - def lamport_initialize(self, rank: int, dtype: torch.dtype): - if dtype == torch.bfloat16: + if dtype == torch.bfloat16 or dtype == torch.float16: neg_zero = 0x8000 dsize = 2 memset_func = cuda.cuMemsetD16 @@ -878,16 +870,6 @@ def get_multicast_ptr(self) -> int: """Get the raw multicast pointer""" return self.mcast_device_memory.get_multicast_ptr() - def get_multicast_ptr_as_int64(self) -> int: - """Get the multicast pointer as int64""" - return self.get_multicast_ptr() - def get_buffer_ptrs_dev(self) -> int: """Get the buffer pointers device array""" return self.mcast_device_memory.get_buffer_ptrs_dev() - - def get_buffer_ptrs_dev_as_int64(self) -> int: - """Get the buffer pointers device as int64 (returning first UC pointer)""" - ptrs = self.mcast_device_memory.get_buffer_ptrs_host() - assert ptrs is not None - return ptrs[0] if ptrs else 0 From fbb5689cfadb25fb7ee4a39e9eca694bf8e6f12d Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 24 Jul 2025 13:58:53 -0700 Subject: [PATCH 5/8] Fix wrong API call. 
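
get_multicast_ptr_as_int64() was removed in the previous commit, but the test
still called it; switch to get_multicast_ptr(), which returns the same raw
multicast address. As a rough sketch (names as used by this series, shown
only for illustration), the pointer setup in the test now looks like:

    mcast_buffer, buffer_flags, max_num_elements = (
        trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype)
    )
    multicast_ptr = mcast_buffer.get_multicast_ptr()      # raw MC pointer (int)
    buffer_ptrs_dev = mcast_buffer.get_buffer_ptrs_dev()  # device array of per-rank UC ptrs
    unicast_ptr = mcast_buffer.mcast_device_memory.get_unicast_ptr(mapping.tp_rank)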
--- tests/test_trtllm_mnnvl_allreduce.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py index b219ee52a7..314a7d3915 100644 --- a/tests/test_trtllm_mnnvl_allreduce.py +++ b/tests/test_trtllm_mnnvl_allreduce.py @@ -221,7 +221,7 @@ def test_mnnvl_allreduce_full(monkeypatch, seq_lens: list[int], fusion: bool): trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype) ) - multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr_as_int64() + multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr() buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev() unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr( mapping.tp_rank From 02f96c1a5c9caec05bc72e7746d61ee267bee278 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 24 Jul 2025 14:56:39 -0700 Subject: [PATCH 6/8] Add test cases --- tests/test_trtllm_mnnvl_allreduce.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py index 314a7d3915..229807d787 100644 --- a/tests/test_trtllm_mnnvl_allreduce.py +++ b/tests/test_trtllm_mnnvl_allreduce.py @@ -155,7 +155,6 @@ def func( """Main test function that runs on each MPI rank""" -# seq_lens = [1, 4, 32, 128] @pytest.mark.parametrize( "seq_lens", [ @@ -167,7 +166,11 @@ def func( ], ) # Test with different sequence length lists @pytest.mark.parametrize("fusion", [False, True]) -def test_mnnvl_allreduce_full(monkeypatch, seq_lens: list[int], fusion: bool): +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) +def test_mnnvl_allreduce_full( + monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int +): monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. 
# Get MPI info @@ -202,11 +205,8 @@ def test_mnnvl_allreduce_full(monkeypatch, seq_lens: list[int], fusion: bool): f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}" ) - hidden_size = 7168 - dtype = torch.bfloat16 tensor_parallel_size = world_size eps = 1e-5 - torch.manual_seed(42) # Track if this rank failed @@ -231,7 +231,7 @@ def test_mnnvl_allreduce_full(monkeypatch, seq_lens: list[int], fusion: bool): for seq_len in seq_lens: if rank == 0: print( - f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}" + f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}" ) # Generate test data (same on all ranks due to same seed) @@ -289,13 +289,13 @@ def test_mnnvl_allreduce_full(monkeypatch, seq_lens: list[int], fusion: bool): # Synchronize before next test trtllm_mnnvl_ar.mpi_barrier() - print(f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}") + print( + f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}" + ) except Exception as e: rank_failed = True - failure_message = ( - f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion} failed: {e}" - ) + failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) # Gather failure status from all ranks all_failures = MPI.COMM_WORLD.allgather(rank_failed) From 092a933464162e3ca153dbe604aedc3d8d176f5f Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 24 Jul 2025 15:09:35 -0700 Subject: [PATCH 7/8] Add reference rmsnorm for fp32 --- tests/test_trtllm_mnnvl_allreduce.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py index 229807d787..252ac334bd 100644 --- a/tests/test_trtllm_mnnvl_allreduce.py +++ b/tests/test_trtllm_mnnvl_allreduce.py @@ -16,6 +16,18 @@ from flashinfer.norm import rmsnorm +# Fallback rmsnorm for FP32 +def _rmsnorm_fp32( + input: torch.Tensor, weight: torch.Tensor, eps: float +) -> torch.Tensor: + return torch.nn.functional.rms_norm( + input, + [input.shape[-1]], + weight, + eps, + ) + + @torch.inference_mode() def row_linear_residual_norm_fusion_forward( x: torch.Tensor, @@ -260,7 +272,10 @@ def test_mnnvl_allreduce_full( residual_out.device, norm_weight.device ) ) - norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) + if not dtype in [torch.bfloat16, torch.float16]: + norm_out = _rmsnorm_fp32(residual_out, norm_weight, eps) + else: + norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) reference_output = (norm_out, residual_out) else: From 2ad7c76f872dd37c70932e8cce047ed4a5b48231 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 24 Jul 2025 15:14:52 -0700 Subject: [PATCH 8/8] Bypass fp32 --- tests/test_trtllm_mnnvl_allreduce.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py index 252ac334bd..860da8535a 100644 --- a/tests/test_trtllm_mnnvl_allreduce.py +++ b/tests/test_trtllm_mnnvl_allreduce.py @@ -16,18 +16,6 @@ from flashinfer.norm import rmsnorm -# Fallback rmsnorm for FP32 -def _rmsnorm_fp32( - input: torch.Tensor, weight: torch.Tensor, eps: float -) -> torch.Tensor: - return torch.nn.functional.rms_norm( - input, - [input.shape[-1]], - weight, - eps, - ) - - @torch.inference_mode() def row_linear_residual_norm_fusion_forward( x: torch.Tensor, @@ -178,7 +166,7 @@ def func( ], ) # Test with 
different sequence length lists @pytest.mark.parametrize("fusion", [False, True]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192]) def test_mnnvl_allreduce_full( monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int @@ -272,10 +260,7 @@ def test_mnnvl_allreduce_full( residual_out.device, norm_weight.device ) ) - if not dtype in [torch.bfloat16, torch.float16]: - norm_out = _rmsnorm_fp32(residual_out, norm_weight, eps) - else: - norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) + norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False) reference_output = (norm_out, residual_out) else: