2 changes: 2 additions & 0 deletions csrc/trtllm_mnnvl_allreduce.cu
@@ -37,6 +37,8 @@ void trtllm_mnnvl_all_reduce(at::Tensor& in, int64_t multicast_buffer_ptr, int64
int64_t token_dim = in.size(1);

// Validate input parameters
TORCH_CHECK(token_dim % (sizeof(float2) / sizeof(c_type)) == 0,
"token_dim must be divisible by ", sizeof(float2) / sizeof(c_type));
TORCH_CHECK(nranks >= 2 && nranks <= 64, "nranks must be between 2 and 64, got ", nranks);
TORCH_CHECK(rank >= 0 && rank < nranks, "rank must be between 0 and nranks-1, got ", rank);
TORCH_CHECK(out.has_value() || !wait_for_results,
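
For context on the new check: token_dim must be a multiple of the number of c_type elements that fit in one 8-byte float2, presumably because the kernel moves data in float2-sized chunks (4 elements for fp16/bf16, 2 for fp32). A minimal Python sketch of the same condition; the helper names here are illustrative, not part of the PR:

import torch

def _elems_per_float2(dtype: torch.dtype) -> int:
    # sizeof(float2) is 8 bytes; element_size() is 2 for fp16/bf16 and 4 for fp32
    return 8 // torch.tensor([], dtype=dtype).element_size()

def _check_token_dim(token_dim: int, dtype: torch.dtype) -> None:
    factor = _elems_per_float2(dtype)
    if token_dim % factor:
        raise ValueError(f"token_dim must be divisible by {factor}, got {token_dim}")

_check_token_dim(4096, torch.bfloat16)  # passes: 4096 % 4 == 0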
95 changes: 50 additions & 45 deletions flashinfer/comm/mnnvl.py
@@ -18,7 +18,7 @@
import platform
import sys
from dataclasses import dataclass
from typing import List
from typing import List, Optional

import pynvml
import torch
@@ -110,6 +110,26 @@ def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool:
return False


def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
"""
Allocate device memory and copy a host array of pointers into it. Returns None for an empty input.
"""
if not host_ptr_array:
return None

ArrayType = ctypes.c_uint64 * len(host_ptr_array)
c_array = ArrayType(*host_ptr_array)
size_in_bytes = ctypes.sizeof(c_array)

device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes))
checkCudaErrors(
cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes)
)
# c_array is freed by the Python GC; the caller owns the device allocation

return device_ptr
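
A minimal usage sketch for this helper (pointer values are made up; assumes an active CUDA context, as the McastDeviceMemory code below has):

host_ptrs = [0x7F0000000000, 0x7F0000200000]    # e.g. per-rank unicast pointers
dev_table = alloc_and_copy_to_cuda(host_ptrs)   # CUdeviceptr to a uint64[2] array on the GPU
# ... pass int(dev_table) wherever a device-resident pointer table (void**) is expected ...
checkCudaErrors(cuda.cuMemFree(dev_table))      # caller owns and eventually frees the copy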


class MpiComm:
_comm: MPI.Intracomm = MPI.COMM_WORLD

@@ -423,7 +443,9 @@ def __init__(
# CUDA memory handles and pointers
self.mc_ptr = 0 # CUdeviceptr mMcPtr
self.uc_ptrs: List[int] = [] # std::vector<CUdeviceptr> mUcPtrs
self.signal_pads_dev: List[int] = [] # std::vector<CUdeviceptr> mSignalPadsDev
self.signal_pads: List[int] = [] # mSignalPads
self.signal_pads_dev = 0 # CUdeviceptr to a device array of signal pad pointers (device copy of signal_pads)
self.uc_ptrs_dev = 0 # CUdeviceptr to a device array of unicast pointers (device copy of uc_ptrs)
self.mc_handle = 0 # CUmemGenericAllocationHandle mMcHandle
self.uc_handles: List[int] = (
[]
@@ -475,14 +497,18 @@ def __init__(
raise NotImplementedError("Single-node NVLS allocation not implemented yet")

# Initialize signal pads
self.signal_pads_dev = [0] * self.group_size
self.signal_pads = [0] * self.group_size
for i in range(self.group_size):
self.signal_pads_dev[i] = self.uc_ptrs[i] + self.signal_pad_offset
self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset
if i == self.group_rank:
checkCudaErrors(
cuda.cuMemsetD8(self.signal_pads_dev[i], 0, self.SIGNAL_PAD_SIZE)
cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE)
)

# Create device pointers
self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads)
self.uc_ptrs_dev = alloc_and_copy_to_cuda(self.uc_ptrs)

def __del__(self):
"""Destructor - cleanup allocated memory"""

@@ -505,6 +531,12 @@ def __del__(self):
print(f"Destructor: CUDA context invalid, skipping cleanup: {e}")
return

# Free device pointers
if self.signal_pads_dev:
checkCudaErrors(cuda.cuMemFree(self.signal_pads_dev))
if self.uc_ptrs_dev:
checkCudaErrors(cuda.cuMemFree(self.uc_ptrs_dev))

# Unmap UC regions and release their handles
if hasattr(self, "uc_handles") and self.uc_handles:
for rank in range(self.group_size):
@@ -541,14 +573,22 @@ def __del__(self):
except Exception as e:
print(f"Destructor: Failed to release MC handle: {e}")

def get_signal_pad_ptrs_dev(self) -> List[int]:
def get_signal_pad_ptrs_host(self) -> List[int]:
"""Get the raw array of signal pad pointers to all ranks (including self)"""
return self.signal_pads_dev
return self.signal_pads

def get_buffer_ptrs_dev(self) -> List[int]:
def get_buffer_ptrs_host(self) -> List[int]:
"""Get the raw array of unicast pointers to all ranks (including self)"""
return self.uc_ptrs

def get_signal_pad_ptrs_dev(self) -> int:
"""Get the raw array of signal pad pointers to all ranks (including self)"""
return self.signal_pads_dev

def get_buffer_ptrs_dev(self) -> int:
"""Get the raw array of unicast pointers to all ranks (including self)"""
return self.uc_ptrs_dev
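
To make the host/device split concrete, a hedged sketch of how a caller might use the two getter families (mcast_mem stands for a McastDeviceMemory instance; the kernel-side cast is an assumption):

ptrs_host = mcast_mem.get_buffer_ptrs_host()  # List[int]: one raw unicast pointer per rank
ptrs_dev = mcast_mem.get_buffer_ptrs_dev()    # int: CUdeviceptr to a device-resident uint64 array
# The host list is for Python-side bookkeeping; the device pointer can be reinterpreted
# as void** inside a kernel without any further host-to-device copies.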

def get_unicast_ptr(self, rank: int) -> int:
"""Get the raw unicast pointer to a given rank"""
if rank >= len(self.uc_ptrs):
@@ -755,16 +795,8 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
)
)

def get_multicast_ptr_as_int64(self) -> int:
"""Get multicast pointer as int64 (legacy compatibility)"""
return self.get_multicast_ptr()

def get_buffer_ptrs_dev_as_int64(self) -> int:
"""Get buffer pointers device as int64 (returning first UC pointer for now) (legacy compatibility)"""
return self.uc_ptrs[0] if self.uc_ptrs else 0

def lamport_initialize(self, rank: int, dtype: torch.dtype):
if dtype == torch.bfloat16:
if dtype == torch.bfloat16 or dtype == torch.float16:
neg_zero = 0x8000
dsize = 2
memset_func = cuda.cuMemsetD16
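
Treating fp16 and bf16 uniformly here relies on 0x8000 being the negative-zero bit pattern in both formats (sign bit set, all other bits clear). A standalone check, independent of this PR:

import math
import struct

half = struct.unpack("<e", struct.pack("<H", 0x8000))[0]        # float16 bits -> Python float
bf16 = struct.unpack("<f", struct.pack("<I", 0x8000 << 16))[0]  # bf16 occupies the upper half of fp32
assert half == 0.0 and math.copysign(1.0, half) == -1.0         # i.e. -0.0
assert bf16 == 0.0 and math.copysign(1.0, bf16) == -1.0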
@@ -838,33 +870,6 @@ def get_multicast_ptr(self) -> int:
"""Get the raw multicast pointer"""
return self.mcast_device_memory.get_multicast_ptr()

def get_multicast_ptr_as_int64(self) -> int:
"""Get the multicast pointer as int64"""
return self.get_multicast_ptr()

def get_buffer_ptrs_dev(self) -> List[int]:
def get_buffer_ptrs_dev(self) -> int:
"""Get the buffer pointers device array"""
return self.mcast_device_memory.get_buffer_ptrs_dev()

def get_buffer_ptrs_dev_as_int64(self) -> int:
"""Get the buffer pointers device as int64 (returning first UC pointer)"""
ptrs = self.get_buffer_ptrs_dev()
assert ptrs is not None
return ptrs[0] if ptrs else 0

def get_buffer_ptrs_dev_as_ctypes_ptr(self) -> int:
"""
Get buffer pointers as ctypes array pointer (equivalent to C++ void**).
Returns the address of a ctypes array that can be cast to int64_t and back to void**.

This matches the C++ pattern:
reinterpret_cast<int64_t>(reinterpret_cast<void**>(mUcPtrs.data()))
"""
# Create ctypes array of void pointers
ArrayType = ctypes.c_void_p * len(self.mcast_device_memory.uc_ptrs)
self._buffer_ptrs_array = ArrayType(
*self.mcast_device_memory.uc_ptrs
) # Keep reference to prevent GC

# Return the address of this array (equivalent to .data() in C++)
return ctypes.cast(self._buffer_ptrs_array, ctypes.c_void_p).value
4 changes: 2 additions & 2 deletions flashinfer/comm/trtllm_mnnvl_ar.py
@@ -184,9 +184,9 @@ def get_allreduce_mnnvl_workspace(
mpi_barrier()

# This is a buffer to maintain the state of this allreduce Op
# [Buffer_ptr, Clear_ptr, Buffer_size, atomic access counter]
# [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
buffer_flags = torch.tensor(
[0, 2, max_num_elements, 0],
[0, 2, max_num_elements, 0, 0],
dtype=torch.uint32,
device=torch.device("cuda", mapping.local_rank),
)
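
A short sketch of the widened flags tensor (assumes a CUDA device is available); the index names are descriptive labels for this review, not identifiers from the source, and max_num_elements is a placeholder:

import torch

BUFFER_PTR, CLEAR_PTR, BUFFER_SIZE, NUM_TOKENS_PREV, ACCESS_COUNTER = range(5)
max_num_elements = 8192  # placeholder for illustration
buffer_flags = torch.tensor(
    [0, 2, max_num_elements, 0, 0], dtype=torch.uint32, device="cuda"
)
assert buffer_flags[NUM_TOKENS_PREV].item() == 0  # the newly tracked token count starts at zero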