Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions mooncake-integration/transfer_engine/transfer_engine_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,14 +583,15 @@ int TransferEnginePy::transferCheckStatus(batch_id_t batch_id) {
}

int TransferEnginePy::batchRegisterMemory(std::vector<uintptr_t> buffer_addresses,
std::vector<size_t> capacities) {
std::vector<size_t> capacities,
const std::string &location) {
pybind11::gil_scoped_release release;
auto batch_size = buffer_addresses.size();
std::vector<BufferEntry> buffers;
for (size_t i = 0; i < batch_size; i ++ ) {
buffers.push_back(BufferEntry{(void *)buffer_addresses[i], capacities[i]});
}
return engine_->registerLocalMemoryBatch(buffers, kWildcardLocation);
return engine_->registerLocalMemoryBatch(buffers, location);
}

int TransferEnginePy::batchUnregisterMemory(std::vector<uintptr_t> buffer_addresses) {
Expand All @@ -603,9 +604,9 @@ int TransferEnginePy::batchUnregisterMemory(std::vector<uintptr_t> buffer_addres
return engine_->unregisterLocalMemoryBatch(buffers);
}

int TransferEnginePy::registerMemory(uintptr_t buffer_addr, size_t capacity) {
int TransferEnginePy::registerMemory(uintptr_t buffer_addr, size_t capacity, const std::string &location) {
char *buffer = reinterpret_cast<char *>(buffer_addr);
return engine_->registerLocalMemory(buffer, capacity);
return engine_->registerLocalMemory(buffer, capacity, location);
}

int TransferEnginePy::unregisterMemory(uintptr_t buffer_addr) {
Expand Down Expand Up @@ -656,9 +657,19 @@ PYBIND11_MODULE(engine, m) {
.def("write_bytes_to_buffer", &TransferEnginePy::writeBytesToBuffer)
.def("read_bytes_from_buffer",
&TransferEnginePy::readBytesFromBuffer)
.def("register_memory", &TransferEnginePy::registerMemory)
.def("register_memory",
&TransferEnginePy::registerMemory,
py::arg("buffer_addr"),
py::arg("capacity"),
py::arg("location") = kWildcardLocation
)
.def("unregister_memory", &TransferEnginePy::unregisterMemory)
.def("batch_register_memory", &TransferEnginePy::batchRegisterMemory)
.def("batch_register_memory",
&TransferEnginePy::batchRegisterMemory,
py::arg("buffer_addresses"),
py::arg("capacities"),
py::arg("location") = kWildcardLocation
)
.def("batch_unregister_memory", &TransferEnginePy::batchUnregisterMemory)
.def("get_first_buffer_address",
&TransferEnginePy::getFirstBufferAddress);
Expand Down
4 changes: 2 additions & 2 deletions mooncake-integration/transfer_engine/transfer_engine_py.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,12 @@ class TransferEnginePy {
}

// FOR EXPERIMENT ONLY
int registerMemory(uintptr_t buffer_addr, size_t capacity);
int registerMemory(uintptr_t buffer_addr, size_t capacity, const std::string &location = kWildcardLocation);

// must be called before TransferEnginePy::~TransferEnginePy()
int unregisterMemory(uintptr_t buffer_addr);

int batchRegisterMemory(std::vector<uintptr_t> buffer_addresses, std::vector<size_t> capacities);
int batchRegisterMemory(std::vector<uintptr_t> buffer_addresses, std::vector<size_t> capacities, const std::string &location = kWildcardLocation);

int batchUnregisterMemory(std::vector<uintptr_t> buffer_addresses);

Expand Down
81 changes: 66 additions & 15 deletions mooncake-wheel/tests/transfer_engine_initiator_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import unittest
import ctypes
import os
import random
import string
from mooncake.engine import TransferEngine


def generate_random_string(length):
    """Return a random string of ``length`` characters.

    Characters are drawn (with replacement) from ASCII letters, digits and
    punctuation.
    """
    chars = string.ascii_letters + string.digits + string.punctuation
    return ''.join(random.choices(chars, k=length))


class TestVLLMAdaptorTransfer(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand All @@ -24,11 +32,6 @@ def setUpClass(cls):

def test_random_write_circle_times(self):
"""Test circle times of random string write/read via buffer transfer."""
import random, string

def generate_random_string(length):
chars = string.ascii_letters + string.digits + string.punctuation
return ''.join(random.choices(chars, k=length))

adaptor = self.adaptor
circles = self.circle
Expand Down Expand Up @@ -70,11 +73,6 @@ def generate_random_string(length):

def test_batch_write_read(self):
"""Test batch_transfer_sync_write and batch_transfer_sync_read for batch write/read consistency."""
import random, string

def generate_random_string(length):
chars = string.ascii_letters + string.digits + string.punctuation
return ''.join(random.choices(chars, k=length))

adaptor = self.adaptor
batch_size = 100 # Adjust batch size if needed
Expand Down Expand Up @@ -133,11 +131,6 @@ def generate_random_string(length):

def test_async_batch_write_read(self):
"""Test batch_transfer_async_write and batch_transfer_async_read for batch write/read consistency."""
import random, string

def generate_random_string(length):
chars = string.ascii_letters + string.digits + string.punctuation
return ''.join(random.choices(chars, k=length))

adaptor = self.adaptor
batch_size = 100 # Adjust batch size if needed
Expand Down Expand Up @@ -200,5 +193,63 @@ def generate_random_string(length):

print(f"[✓] {circles} rounds of batch_write_async_read passed, batch size {batch_size}.")

def run_test_register_memory(self, dst_addr, with_location):
    """Round-trip random data through a locally registered ctypes buffer.

    A 10 KiB buffer is registered with the adaptor, either with the explicit
    location "cpu" (``with_location`` true) or with the binding's default
    location. For ``self.circle`` iterations a random payload is written to
    ``dst_addr`` on ``self.target_server_name``, read back into a different
    offset of the local buffer, and compared with what was sent. The buffer
    is always unregistered at the end.
    """
    adaptor = self.adaptor
    circles = self.circle
    buffer_size = 10 * 1024
    buffer = ctypes.create_string_buffer(buffer_size)
    buffer_addr = ctypes.addressof(buffer)

    if with_location:
        result = adaptor.register_memory(buffer_addr, buffer_size, "cpu")
    else:
        result = adaptor.register_memory(buffer_addr, buffer_size)
    # Fail here rather than letting a failed registration surface later as a
    # confusing transfer error.
    self.assertEqual(result, 0, f"register_memory failed ({with_location=})")

    try:
        for i in range(circles):
            str_len = random.randint(16, 256)
            src_data = generate_random_string(str_len).encode('utf-8')
            data_len = len(src_data)
            offset = random.randint(0, 1024)
            self.assertLessEqual(offset + data_len, buffer_size)
            buffer[offset:offset + data_len] = src_data

            # Write to the remote end
            result = adaptor.transfer_sync_write(
                self.target_server_name, buffer_addr + offset, dst_addr, data_len
            )
            self.assertEqual(result, 0, f"[{i}] WRITE transferSyncExt failed")

            # Clear the local buffer so the read-back cannot pass on stale data
            clear_data = bytes(data_len)
            buffer[offset:offset + data_len] = clear_data

            # Read it back from the remote end
            dst_offset = random.randint(0, 1024)
            self.assertLessEqual(dst_offset + data_len, buffer_size)

            result = adaptor.transfer_sync_read(
                self.target_server_name, buffer_addr + dst_offset, dst_addr, data_len
            )
            self.assertEqual(result, 0, f"[{i}] READ transferSyncExt failed")

            # Verify data consistency
            read_back = bytes(buffer[dst_offset:dst_offset + data_len])
            self.assertEqual(read_back, src_data, f"[{i}] Data mismatch")

            # Clear the local buffer
            buffer[dst_offset:dst_offset + data_len] = clear_data
        print(f"[✓] {circles} iterations of random write-read with custom buffer passed successfully ({with_location=}).")
    finally:
        adaptor.unregister_memory(buffer_addr)

def test_register_memory(self):
adaptor = self.adaptor
dst_addr = adaptor.get_first_buffer_address(self.target_server_name)

for with_location in [False, True]:
self.run_test_register_memory(dst_addr, with_location)


if __name__ == '__main__':
unittest.main()
Loading