Commit 43e08e9

comm: Optimizations for TRTLLM MNNVL Allreduce (#1321)
## 📌 Description

This PR introduces a series of optimizations to the trtllm_mnnvl_allreduce. These optimizations were also contributed to TensorRT-LLM via [https://github.com/NVIDIA/TensorRT-LLM/pull/5934](https://github.com/NVIDIA/TensorRT-LLM/pull/5934) and [https://github.com/NVIDIA/TensorRT-LLM/pull/6237](https://github.com/NVIDIA/TensorRT-LLM/pull/6237).

- Use a GPU array to pass the UC pointers to the mcast memory.
- Use an L2 reduction to replace the expensive atomicAdd.
- Adjust the point of synchronization for the buffer flag read.
- Optimize the lamport polling performance.
- Clean up the code structure.
- Enhance the unittest to cover more test cases.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 7d2a1c7 commit 43e08e9

5 files changed: +346 -241 lines changed


csrc/trtllm_mnnvl_allreduce.cu

Lines changed: 2 additions & 0 deletions
```diff
@@ -37,6 +37,8 @@ void trtllm_mnnvl_all_reduce(at::Tensor& in, int64_t multicast_buffer_ptr, int64
   int64_t token_dim = in.size(1);
 
   // Validate input parameters
+  TORCH_CHECK(token_dim % (sizeof(float2) / sizeof(c_type)) == 0,
+              "token_dim must be divisible by ", sizeof(float2) / sizeof(c_type));
   TORCH_CHECK(nranks >= 2 && nranks <= 64, "nranks must be between 2 and 64, got ", nranks);
   TORCH_CHECK(rank >= 0 && rank < nranks, "rank must be between 0 and nranks-1, got ", rank);
   TORCH_CHECK(out.has_value() || !wait_for_results,
```
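The new check ties the hidden dimension to the 8-byte `float2` vector width implied by the expression `sizeof(float2) / sizeof(c_type)`: for a 2-byte `c_type` (bfloat16/float16), `token_dim` must be a multiple of 4. A small illustration of the arithmetic (the helper below is ours, not part of the PR):

```python
import torch

SIZEOF_FLOAT2 = 8  # bytes; the vector width implied by the check above

def required_multiple(dtype: torch.dtype) -> int:
    """sizeof(float2) / sizeof(c_type) for a given torch dtype."""
    return SIZEOF_FLOAT2 // torch.tensor([], dtype=dtype).element_size()

assert required_multiple(torch.bfloat16) == 4  # token_dim % 4 == 0 required
assert required_multiple(torch.float16) == 4
assert required_multiple(torch.float32) == 2   # if/where fp32 inputs are used
```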

flashinfer/comm/mnnvl.py

Lines changed: 50 additions & 45 deletions
```diff
@@ -18,7 +18,7 @@
 import platform
 import sys
 from dataclasses import dataclass
-from typing import List
+from typing import List, Optional
 
 import pynvml
 import torch
```
```diff
@@ -110,6 +110,26 @@ def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool:
         return False
 
 
+def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
+    """
+    A helper function that allocates memory on cuda and copies the data from the host to the device.
+    """
+    if not host_ptr_array:
+        return None
+
+    ArrayType = ctypes.c_uint64 * len(host_ptr_array)
+    c_array = ArrayType(*host_ptr_array)
+    size_in_bytes = ctypes.sizeof(c_array)
+
+    device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes))
+    checkCudaErrors(
+        cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes)
+    )
+    # c_array should be freed by GC
+
+    return device_ptr
+
+
 class MpiComm:
     _comm: MPI.Intracomm = MPI.COMM_WORLD
 
```
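This helper is the core of the first optimization: the per-rank UC pointers are packed once into a device-resident `uint64_t` array, so the kernel receives a single device pointer instead of a host-side list. Below is a standalone cuda-python sketch of that round trip (not part of the PR; the pointer values are dummies and are only stored, never dereferenced):

```python
import ctypes
from cuda import cuda


def check(result):
    """cuda-python calls return (CUresult, value...) tuples; unwrap and assert success."""
    err, *rest = result
    assert err == cuda.CUresult.CUDA_SUCCESS, err
    return rest[0] if rest else None


check(cuda.cuInit(0))
dev = check(cuda.cuDeviceGet(0))
ctx = check(cuda.cuDevicePrimaryCtxRetain(dev))
check(cuda.cuCtxSetCurrent(ctx))

host_ptrs = [0x1000, 0x2000, 0x3000]  # stand-ins for per-rank UC pointers
c_array = (ctypes.c_uint64 * len(host_ptrs))(*host_ptrs)
nbytes = ctypes.sizeof(c_array)

d_ptrs = check(cuda.cuMemAlloc(nbytes))  # device uint64_t[len(host_ptrs)]
check(cuda.cuMemcpyHtoD(d_ptrs, ctypes.addressof(c_array), nbytes))

# int(d_ptrs) is what the kernel would now receive and index as uint64_t*.
readback = (ctypes.c_uint64 * len(host_ptrs))()
check(cuda.cuMemcpyDtoH(ctypes.addressof(readback), d_ptrs, nbytes))
assert list(readback) == host_ptrs

check(cuda.cuMemFree(d_ptrs))
check(cuda.cuDevicePrimaryCtxRelease(dev))
```

Because `cuMemcpyHtoD` here is synchronous, the ctypes host array can be garbage-collected as soon as the helper returns (hence the "freed by GC" comment), while the device allocation must be freed explicitly, which the `__del__` hunk below now does.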

```diff
@@ -423,7 +443,9 @@ def __init__(
         # CUDA memory handles and pointers
         self.mc_ptr = 0  # CUdeviceptr mMcPtr
         self.uc_ptrs: List[int] = []  # std::vector<CUdeviceptr> mUcPtrs
-        self.signal_pads_dev: List[int] = []  # std::vector<CUdeviceptr> mSignalPadsDev
+        self.signal_pads: List[int] = []  # mSignalPads
+        self.signal_pads_dev = 0  # std::vector<CUdeviceptr> mSignalPadsDev
+        self.uc_ptrs_dev = 0
         self.mc_handle = 0  # CUmemGenericAllocationHandle mMcHandle
         self.uc_handles: List[int] = (
             []
```
```diff
@@ -475,14 +497,18 @@ def __init__(
             raise NotImplementedError("Single-node NVLS allocation not implemented yet")
 
         # Initialize signal pads
-        self.signal_pads_dev = [0] * self.group_size
+        self.signal_pads = [0] * self.group_size
         for i in range(self.group_size):
-            self.signal_pads_dev[i] = self.uc_ptrs[i] + self.signal_pad_offset
+            self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset
             if i == self.group_rank:
                 checkCudaErrors(
-                    cuda.cuMemsetD8(self.signal_pads_dev[i], 0, self.SIGNAL_PAD_SIZE)
+                    cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE)
                 )
 
+        # Create device pointers
+        self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads)
+        self.uc_ptrs_dev = alloc_and_copy_to_cuda(self.uc_ptrs)
+
     def __del__(self):
         """Destructor - cleanup allocated memory"""
 
```
```diff
@@ -505,6 +531,12 @@ def __del__(self):
             print(f"Destructor: CUDA context invalid, skipping cleanup: {e}")
             return
 
+        # Free device pointers
+        if self.signal_pads_dev:
+            checkCudaErrors(cuda.cuMemFree(self.signal_pads_dev))
+        if self.uc_ptrs_dev:
+            checkCudaErrors(cuda.cuMemFree(self.uc_ptrs_dev))
+
         # Unmap UC regions and release their handles
         if hasattr(self, "uc_handles") and self.uc_handles:
             for rank in range(self.group_size):
```
```diff
@@ -541,14 +573,22 @@ def __del__(self):
             except Exception as e:
                 print(f"Destructor: Failed to release MC handle: {e}")
 
-    def get_signal_pad_ptrs_dev(self) -> List[int]:
+    def get_signal_pad_ptrs_host(self) -> List[int]:
         """Get the raw array of signal pad pointers to all ranks (including self)"""
-        return self.signal_pads_dev
+        return self.signal_pads
 
-    def get_buffer_ptrs_dev(self) -> List[int]:
+    def get_buffer_ptrs_host(self) -> List[int]:
         """Get the raw array of unicast pointers to all ranks (including self)"""
         return self.uc_ptrs
 
+    def get_signal_pad_ptrs_dev(self) -> int:
+        """Get the raw array of signal pad pointers to all ranks (including self)"""
+        return self.signal_pads_dev
+
+    def get_buffer_ptrs_dev(self) -> int:
+        """Get the raw array of unicast pointers to all ranks (including self)"""
+        return self.uc_ptrs_dev
+
     def get_unicast_ptr(self, rank: int) -> int:
         """Get the raw unicast pointer to a given rank"""
         if rank >= len(self.uc_ptrs):
```
```diff
@@ -755,16 +795,8 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
             )
         )
 
-    def get_multicast_ptr_as_int64(self) -> int:
-        """Get multicast pointer as int64 (legacy compatibility)"""
-        return self.get_multicast_ptr()
-
-    def get_buffer_ptrs_dev_as_int64(self) -> int:
-        """Get buffer pointers device as int64 (returning first UC pointer for now) (legacy compatibility)"""
-        return self.uc_ptrs[0] if self.uc_ptrs else 0
-
     def lamport_initialize(self, rank: int, dtype: torch.dtype):
-        if dtype == torch.bfloat16:
+        if dtype == torch.bfloat16 or dtype == torch.float16:
             neg_zero = 0x8000
             dsize = 2
             memset_func = cuda.cuMemsetD16
```
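The lamport path fills the workspace with a sentinel value that the polling loop watches for, and the same bit pattern now serves both 16-bit dtypes: 0x8000 is negative zero for bfloat16 and float16 alike, which is presumably why a single `cuMemsetD16` with `neg_zero = 0x8000` can cover both. A quick way to confirm the bit pattern (illustration only, not PR code):

```python
import torch

bits = torch.tensor([-0x8000], dtype=torch.int16)  # two's complement of -32768 is 0x8000
print(bits.view(torch.float16))   # tensor([-0.], dtype=torch.float16)
print(bits.view(torch.bfloat16))  # tensor([-0.], dtype=torch.bfloat16)
```

Since both dtypes are 2 bytes with the same sentinel bits, the float16 case can reuse the bfloat16 initialization unchanged (`dsize = 2`, same memset).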
```diff
@@ -838,33 +870,6 @@ def get_multicast_ptr(self) -> int:
         """Get the raw multicast pointer"""
         return self.mcast_device_memory.get_multicast_ptr()
 
-    def get_multicast_ptr_as_int64(self) -> int:
-        """Get the multicast pointer as int64"""
-        return self.get_multicast_ptr()
-
-    def get_buffer_ptrs_dev(self) -> List[int]:
+    def get_buffer_ptrs_dev(self) -> int:
         """Get the buffer pointers device array"""
         return self.mcast_device_memory.get_buffer_ptrs_dev()
-
-    def get_buffer_ptrs_dev_as_int64(self) -> int:
-        """Get the buffer pointers device as int64 (returning first UC pointer)"""
-        ptrs = self.get_buffer_ptrs_dev()
-        assert ptrs is not None
-        return ptrs[0] if ptrs else 0
-
-    def get_buffer_ptrs_dev_as_ctypes_ptr(self) -> int:
-        """
-        Get buffer pointers as ctypes array pointer (equivalent to C++ void**).
-        Returns the address of a ctypes array that can be cast to int64_t and back to void**.
-
-        This matches the C++ pattern:
-        reinterpret_cast<int64_t>(reinterpret_cast<void**>(mUcPtrs.data()))
-        """
-        # Create ctypes array of void pointers
-        ArrayType = ctypes.c_void_p * len(self.mcast_device_memory.uc_ptrs)
-        self._buffer_ptrs_array = ArrayType(
-            *self.mcast_device_memory.uc_ptrs
-        )  # Keep reference to prevent GC
-
-        # Return the address of this array (equivalent to .data() in C++)
-        return ctypes.cast(self._buffer_ptrs_array, ctypes.c_void_p).value
```

flashinfer/comm/trtllm_mnnvl_ar.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -184,9 +184,9 @@ def get_allreduce_mnnvl_workspace(
     mpi_barrier()
 
     # This is a buffer to maintain the state of this allreduce Op
-    # [Buffer_ptr, Clear_ptr, Buffer_size, atomic access counter]
+    # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
     buffer_flags = torch.tensor(
-        [0, 2, max_num_elements, 0],
+        [0, 2, max_num_elements, 0, 0],
         dtype=torch.uint32,
         device=torch.device("cuda", mapping.local_rank),
    )
```
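The state tensor grows from four to five uint32 slots, inserting `num_tokens_prev` between the buffer size and the access counter; the name suggests the kernel can now detect when the token count changes between calls, though this diff only changes the layout. Purely as a reading aid (these slot names are ours, not part of the API), the layout can be written as:

```python
import torch

# Illustrative slot names for buffer_flags; the real kernel addresses them by position.
BUFFER_PTR, CLEAR_PTR, BUFFER_SIZE, NUM_TOKENS_PREV, ACCESS_COUNTER = range(5)

max_num_elements = 8192 * 7168  # example size; the real value comes from the workspace setup
buffer_flags = torch.tensor(
    [0, 2, max_num_elements, 0, 0],
    dtype=torch.uint32,
    device="cuda",
)
assert int(buffer_flags[BUFFER_SIZE]) == max_num_elements
```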
