diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp
index 1829aaf4d..dfe2822da 100644
--- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp
+++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp
@@ -53,6 +55,72 @@ TransferEnginePy::~TransferEnginePy() {
     large_buffer_list_.clear();
 }
 
+// NOTE(review): torch is imported lazily inside each function below.  A
+// file-scope `py::module_::import("torch")` would run during static
+// initialization, before the Python interpreter is ready, and crash the
+// extension module at load time.
+
+// Build a 1-D numpy array of T that views `exported_data + offset`.
+// Ownership of the whole `exported_data` allocation is handed to the
+// capsule, which frees it when the numpy array is garbage-collected.
+template <typename T>
+static py::array create_typed_array(char *exported_data, size_t offset,
+                                    size_t length_bytes) {
+    py::capsule free_when_done(
+        exported_data, [](void *p) { delete[] static_cast<char *>(p); });
+    return py::array_t<T>(
+        {static_cast<py::ssize_t>(length_bytes / sizeof(T))},
+        reinterpret_cast<T *>(exported_data + offset), free_when_done);
+}
+
+// Dispatch on the on-wire dtype tag.  Callers must validate the tag first;
+// the default branch exists only to keep the compiler happy.
+static py::array create_array_for_dtype(TensorDtype dtype, char *data,
+                                        size_t offset, size_t length_bytes) {
+    switch (dtype) {
+        case TensorDtype::FLOAT32:
+            return create_typed_array<float>(data, offset, length_bytes);
+        case TensorDtype::FLOAT64:
+            return create_typed_array<double>(data, offset, length_bytes);
+        case TensorDtype::INT8:
+            return create_typed_array<int8_t>(data, offset, length_bytes);
+        case TensorDtype::UINT8:
+            return create_typed_array<uint8_t>(data, offset, length_bytes);
+        case TensorDtype::INT16:
+            return create_typed_array<int16_t>(data, offset, length_bytes);
+        case TensorDtype::UINT16:
+            return create_typed_array<uint16_t>(data, offset, length_bytes);
+        case TensorDtype::INT32:
+            return create_typed_array<int32_t>(data, offset, length_bytes);
+        case TensorDtype::UINT32:
+            return create_typed_array<uint32_t>(data, offset, length_bytes);
+        case TensorDtype::INT64:
+            return create_typed_array<int64_t>(data, offset, length_bytes);
+        case TensorDtype::UINT64:
+            return create_typed_array<uint64_t>(data, offset, length_bytes);
+        case TensorDtype::BOOL:
+            return create_typed_array<bool>(data, offset, length_bytes);
+        default:
+            delete[] data;  // unreachable when callers validate the tag
+            return py::array();
+    }
+}
+
+TensorDtype TransferEnginePy::get_tensor_dtype(py::object dtype_obj) {
+    if (dtype_obj.is_none()) {
+        return TensorDtype::UNKNOWN;
+    }
+
+    py::module_ torch = py::module_::import("torch");
+    // uint16/uint32/uint64 only exist in newer torch releases; probe with
+    // hasattr so an AttributeError cannot escape as error_already_set.
+    auto matches = [&](const char *name) {
+        return py::hasattr(torch, name) && dtype_obj.equal(torch.attr(name));
+    };
+    if (matches("float32")) return TensorDtype::FLOAT32;
+    if (matches("float64")) return TensorDtype::FLOAT64;
+    if (matches("int8")) return TensorDtype::INT8;
+    if (matches("uint8")) return TensorDtype::UINT8;
+    if (matches("int16")) return TensorDtype::INT16;
+    if (matches("uint16")) return TensorDtype::UINT16;
+    if (matches("int32")) return TensorDtype::INT32;
+    if (matches("uint32")) return TensorDtype::UINT32;
+    if (matches("int64")) return TensorDtype::INT64;
+    if (matches("uint64")) return TensorDtype::UINT64;
+    if (matches("bool")) return TensorDtype::BOOL;
+    return TensorDtype::UNKNOWN;
+}
+
 std::vector buildDeviceFilter(const std::string &device_names) {
     std::stringstream ss(device_names);
     std::string item;
@@ -474,6 +542,184 @@ batch_id_t TransferEnginePy::batchTransferAsync(const char *target_hostname,
     return batch_id;
 }
 
+int TransferEnginePy::transferTensorSyncWrite(const char *target_hostname,
+                                              pybind11::object tensor,
+                                              uintptr_t peer_buffer_address) {
+    try {
+        py::module_ torch = py::module_::import("torch");
+        // Real isinstance check instead of substring-matching "Tensor" in
+        // the class name (which also matched unrelated classes).
+        if (!py::isinstance(tensor, torch.attr("Tensor"))) {
+            LOG(ERROR) << "Input is not a PyTorch tensor";
+            return -1;
+        }
+
+        // data_ptr() of a CUDA or non-contiguous tensor must not be fed to
+        // memcpy below: normalize to a contiguous CPU tensor first.
+        if (tensor.attr("is_cuda").cast<bool>()) {
+            tensor = tensor.attr("cpu")();
+        }
+        if (!tensor.attr("is_contiguous")().cast<bool>()) {
+            tensor = tensor.attr("contiguous")();
+        }
+
+        // Gather tensor metadata.
+        uintptr_t data_ptr = tensor.attr("data_ptr")().cast<uintptr_t>();
+        size_t numel = tensor.attr("numel")().cast<size_t>();
+        size_t element_size = tensor.attr("element_size")().cast<size_t>();
+        size_t tensor_size = numel * element_size;
+
+        TensorDtype dtype_enum = get_tensor_dtype(tensor.attr("dtype"));
+        if (dtype_enum == TensorDtype::UNKNOWN) {
+            LOG(ERROR) << "Unsupported tensor dtype!";
+            return -1;
+        }
+
+        // The wire format currently carries at most 4 dimensions.
+        py::tuple shape_tuple = py::cast<py::tuple>(tensor.attr("shape"));
+        int32_t ndim = static_cast<int32_t>(shape_tuple.size());
+        if (ndim > 4) {
+            LOG(ERROR) << "Tensor has more than 4 dimensions: " << ndim;
+            return -1;
+        }
+
+        TensorMetadata metadata;
+        metadata.dtype = static_cast<int32_t>(dtype_enum);
+        metadata.ndim = ndim;
+        for (int i = 0; i < 4; i++) {
+            // Unused dimensions are marked -1 so the reader can skip them.
+            metadata.shape[i] =
+                (i < ndim) ? shape_tuple[i].cast<int32_t>() : -1;
+        }
+
+        // Stage [metadata | tensor bytes] in one registered buffer so a
+        // single transferSync moves everything.
+        size_t total_size = sizeof(TensorMetadata) + tensor_size;
+        uintptr_t local_buffer = allocateManagedBuffer(total_size);
+        if (local_buffer == 0) {
+            LOG(ERROR) << "Failed to allocate combined buffer";
+            return -1;
+        }
+        memcpy(reinterpret_cast<void *>(local_buffer), &metadata,
+               sizeof(TensorMetadata));
+        memcpy(reinterpret_cast<void *>(local_buffer + sizeof(TensorMetadata)),
+               reinterpret_cast<void *>(data_ptr), tensor_size);
+
+        int ret = transferSync(target_hostname, local_buffer,
+                               peer_buffer_address, total_size,
+                               TransferOpcode::WRITE);
+
+        freeManagedBuffer(local_buffer, total_size);
+        return ret;
+    } catch (const pybind11::error_already_set &e) {
+        LOG(ERROR) << "Failed to access tensor data: " << e.what();
+        return -1;
+    }
+}
+
+pybind11::object TransferEnginePy::transferTensorSyncRead(
+    const char *target_hostname, uintptr_t peer_buffer_address,
+    size_t total_size) {
+    try {
+        // A payload smaller than the metadata header cannot be valid and
+        // would make `tensor_size` below underflow.
+        if (total_size < sizeof(TensorMetadata)) {
+            LOG(ERROR) << "total_size too small: " << total_size;
+            return pybind11::none();
+        }
+
+        uintptr_t local_buffer = allocateManagedBuffer(total_size);
+        if (local_buffer == 0) {
+            LOG(ERROR) << "Failed to allocate receive buffer, size: "
+                       << total_size;
+            return pybind11::none();
+        }
+
+        int ret;
+        {
+            // The bulk transfer performs no Python work; release the GIL
+            // so other Python threads can run meanwhile.  (The binding is
+            // entered with the GIL held, so no gil_scoped_acquire is needed
+            // on the return paths.)
+            py::gil_scoped_release release;
+            ret = transferSync(target_hostname, local_buffer,
+                               peer_buffer_address, total_size,
+                               TransferOpcode::READ);
+        }
+        if (ret != 0) {
+            freeManagedBuffer(local_buffer, total_size);
+            LOG(ERROR) << "Failed to transfer data, ret: " << ret;
+            return pybind11::none();
+        }
+
+        // Parse and validate the metadata header.
+        TensorMetadata metadata;
+        memcpy(&metadata, reinterpret_cast<void *>(local_buffer),
+               sizeof(TensorMetadata));
+        VLOG(1) << "Read metadata: dtype=" << metadata.dtype
+                << ", ndim=" << metadata.ndim << ", shape=["
+                << metadata.shape[0] << "," << metadata.shape[1] << ","
+                << metadata.shape[2] << "," << metadata.shape[3] << "]";
+
+        TensorDtype dtype_enum = static_cast<TensorDtype>(metadata.dtype);
+        if (metadata.ndim < 0 || metadata.ndim > 4 || metadata.dtype < 0 ||
+            metadata.dtype > 10 || dtype_enum == TensorDtype::UNKNOWN) {
+            freeManagedBuffer(local_buffer, total_size);
+            LOG(ERROR) << "Invalid tensor metadata: dtype=" << metadata.dtype
+                       << ", ndim=" << metadata.ndim;
+            return pybind11::none();
+        }
+
+        size_t tensor_size = total_size - sizeof(TensorMetadata);
+
+        // Copy into a plain heap buffer whose lifetime the numpy capsule
+        // manages, then release the registered buffer.
+        char *exported_data = new char[total_size];
+        memcpy(exported_data, reinterpret_cast<void *>(local_buffer),
+               total_size);
+        freeManagedBuffer(local_buffer, total_size);
+
+        py::object np_array = create_array_for_dtype(
+            dtype_enum, exported_data, sizeof(TensorMetadata), tensor_size);
+
+        // Restore the original shape (ndim == 0 stays a flat array).
+        if (metadata.ndim > 0) {
+            std::vector<int64_t> shape_vec;
+            for (int i = 0; i < metadata.ndim; i++) {
+                if (metadata.shape[i] > 0) {  // only keep valid dimensions
+                    shape_vec.push_back(metadata.shape[i]);
+                }
+            }
+            if (!shape_vec.empty()) {
+                np_array = np_array.attr("reshape")(py::cast(shape_vec));
+            }
+        }
+
+        // torch.from_numpy shares memory with np_array; the capsule keeps
+        // the underlying buffer alive for the tensor's lifetime.
+        py::module_ torch = py::module_::import("torch");
+        return torch.attr("from_numpy")(np_array);
+    } catch (const pybind11::error_already_set &e) {
+        LOG(ERROR) << "Failed to get tensor data: " << e.what();
+        return pybind11::none();
+    }
+}
+
 int TransferEnginePy::getBatchTransferStatus(const std::vector& batch_ids) {
     pybind11::gil_scoped_release release;
     TransferStatus status;
@@ -624,6 +870,20 @@ uintptr_t TransferEnginePy::getFirstBufferAddress(
 namespace py = pybind11;
 
 PYBIND11_MODULE(engine, m) {
+    py::enum_<TensorDtype>(m, "TensorDtype", py::arithmetic())
+        .value("FLOAT32", TensorDtype::FLOAT32)
+        .value("FLOAT64", TensorDtype::FLOAT64)
+        .value("INT8", TensorDtype::INT8)
+        .value("UINT8", TensorDtype::UINT8)
+        .value("INT16", TensorDtype::INT16)
+        .value("UINT16", TensorDtype::UINT16)
+        .value("INT32", TensorDtype::INT32)
+        .value("UINT32", TensorDtype::UINT32)
+        .value("INT64", TensorDtype::INT64)
+        .value("UINT64", TensorDtype::UINT64)
+        .value("BOOL", TensorDtype::BOOL)
+        .value("UNKNOWN", TensorDtype::UNKNOWN)
+        .export_values();
     py::enum_ transfer_opcode(
         m, "TransferOpcode", py::arithmetic());
     transfer_opcode.value("Read", TransferEnginePy::TransferOpcode::READ)
@@ -648,6 +908,8 @@ PYBIND11_MODULE(engine, m) {
         .def("transfer_sync", &TransferEnginePy::transferSync)
         .def("batch_transfer_sync", &TransferEnginePy::batchTransferSync)
         .def("batch_transfer_async", &TransferEnginePy::batchTransferAsync)
+        .def("transfer_tensor_sync_write", &TransferEnginePy::transferTensorSyncWrite)
+        .def("transfer_tensor_sync_read", &TransferEnginePy::transferTensorSyncRead)
         .def("get_batch_transfer_status",
              &TransferEnginePy::getBatchTransferStatus)
         .def("transfer_submit_write", &TransferEnginePy::transferSubmitWrite)
diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.h b/mooncake-integration/transfer_engine/transfer_engine_py.h
index 558b59ffa..8a691b058 100644
--- a/mooncake-integration/transfer_engine/transfer_engine_py.h
+++ b/mooncake-integration/transfer_engine/transfer_engine_py.h
@@ -15,6 +15,7 @@
+#include <pybind11/numpy.h>
@@ -31,6 +32,7 @@
 #include "transport/transport.h"
 
 using namespace mooncake;
+namespace py = pybind11;
 
 const static size_t kDefaultBufferCapacity = 2ull * 1024 * 1024 * 1024;
 const static size_t kSlabSizeKBTabLen = 16;
@@ -40,6 +42,17 @@ const static size_t kSlabSizeKB[] = {
     512,       1024,      2 * 1024,  4 * 1024, 8 * 1024,
     16 * 1024, 32 * 1024, 64 * 1024, 128 * 1024, 256 * 1024};
 
+// On-wire dtype tag; the numeric values travel between peers, so they are
+// part of the wire format — never reorder or renumber them.
+enum class TensorDtype : int32_t {
+    FLOAT32 = 0, FLOAT64 = 1, INT8 = 2, UINT8 = 3, INT16 = 4, UINT16 = 5,
+    INT32 = 6, UINT32 = 7, INT64 = 8, UINT64 = 9, BOOL = 10, UNKNOWN = -1
+};
+
+// Fixed-size header prepended to the tensor payload.
+struct TensorMetadata {
+    int32_t dtype;     // TensorDtype tag
+    int32_t ndim;      // number of valid entries in shape[]
+    int32_t shape[4];  // unused trailing entries are -1
+} __attribute__((packed));
+
+static_assert(sizeof(TensorMetadata) == 24,
+              "TensorMetadata layout is part of the wire format");
+
 class TransferEnginePy {
    public:
     enum class TransferOpcode { READ = 0, WRITE = 1 };
@@ -114,6 +127,14 @@ class TransferEnginePy {
                                  const std::vector &peer_buffer_addresses,
                                  const std::vector &lengths,
                                  TransferOpcode opcode);
+
+    // Writes [TensorMetadata | tensor bytes] to the peer buffer. Returns 0
+    // on success, -1 on failure.
+    int transferTensorSyncWrite(const char *target_hostname,
+                                pybind11::object tensor,
+                                uintptr_t peer_buffer_address);
+
+    // Reads back a tensor previously written with transferTensorSyncWrite.
+    // Returns None on failure.
+    pybind11::object transferTensorSyncRead(const char *target_hostname,
+                                            uintptr_t peer_buffer_address,
+                                            size_t length);
 
     int getBatchTransferStatus(const std::vector &batch_ids);
 
@@ -149,6 +173,8 @@ class TransferEnginePy {
 
     int doBuddyAllocate(int class_id);
 
+    // Maps a torch.dtype object to the wire tag (UNKNOWN if unsupported).
+    TensorDtype get_tensor_dtype(py::object dtype_obj);
+
    private:
     std::shared_ptr engine_;
     Transport *xport_;
diff --git a/mooncake-wheel/tests/test_transfer_tensor.py b/mooncake-wheel/tests/test_transfer_tensor.py
new file mode 100644
index 000000000..cd8b10ef4
--- /dev/null
+++ b/mooncake-wheel/tests/test_transfer_tensor.py
@@ -0,0 +1,101 @@
+import os
+import time
+import unittest
+
+import torch
+from mooncake import engine
+
+# Size of the C++ TensorMetadata struct: int32 dtype + int32 ndim +
+# int32 shape[4], packed -> 24 bytes.
+METADATA_SIZE = 24
+
+
+def get_transfer_engine():
+    """Initialize and set up a transfer engine from environment settings."""
+    protocol = os.getenv("PROTOCOL", "tcp")
+    device_name = os.getenv("DEVICE_NAME", "")
+    local_hostname = os.getenv("LOCAL_HOSTNAME", "localhost")
+    metadata_server = os.getenv("MC_METADATA_SERVER", "P2PHANDSHAKE")
+
+    te = engine.TransferEngine()
+    retcode = te.initialize(local_hostname, metadata_server, protocol,
+                            device_name)
+    if retcode:
+        raise RuntimeError(
+            f"Failed to initialize transfer engine. Return code: {retcode}")
+    return te
+
+
+class TestTransferEngineTensor(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """Initialize the transfer engines for both sender and receiver."""
+        cls.sender_engine = get_transfer_engine()
+        cls.receiver_engine = get_transfer_engine()
+
+        cls.sender_hostname = f"localhost:{cls.sender_engine.get_rpc_port()}"
+        cls.receiver_hostname = (
+            f"localhost:{cls.receiver_engine.get_rpc_port()}")
+
+        print(f"Sender engine on: {cls.sender_hostname}")
+        print(f"Receiver engine on: {cls.receiver_hostname}")
+
+        # Give both engines a moment to finish coming up.
+        time.sleep(1)
+
+    def _roundtrip(self, tensor, compare):
+        """Write `tensor` to the receiver, read it back, and verify dtype,
+        shape and contents. `compare` is torch.equal (exact types) or
+        torch.allclose (floating point)."""
+        total_size = METADATA_SIZE + tensor.numel() * tensor.element_size()
+
+        receiver_buffer = self.receiver_engine.allocate_managed_buffer(
+            total_size)
+        self.assertNotEqual(receiver_buffer, 0)
+        try:
+            result = self.sender_engine.transfer_tensor_sync_write(
+                self.receiver_hostname, tensor, receiver_buffer)
+            self.assertEqual(result, 0)
+
+            # Read back on the sender side, simulating cross-node retrieval.
+            retrieved = self.sender_engine.transfer_tensor_sync_read(
+                self.receiver_hostname, receiver_buffer, total_size)
+
+            self.assertIsNotNone(retrieved)
+            self.assertEqual(retrieved.dtype, tensor.dtype)
+            self.assertEqual(tuple(retrieved.shape), tuple(tensor.shape))
+            self.assertTrue(compare(tensor, retrieved))
+        finally:
+            self.receiver_engine.free_managed_buffer(receiver_buffer,
+                                                     total_size)
+
+    def test_transfer_tensor_float32(self):
+        """Test transferring a float32 tensor."""
+        self._roundtrip(
+            torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32),
+            torch.allclose)
+
+    def test_transfer_tensor_int32(self):
+        """Test transferring an int32 tensor."""
+        self._roundtrip(torch.tensor([1, 2, 3, 4], dtype=torch.int32),
+                        torch.equal)
+
+    def test_transfer_tensor_bool(self):
+        """Test transferring a bool tensor."""
+        self._roundtrip(
+            torch.tensor([True, False, False, True], dtype=torch.bool),
+            torch.equal)
+
+    def test_transfer_tensor_large(self):
+        """Test transferring a large tensor."""
+        self._roundtrip(torch.randn(1000, dtype=torch.float32),
+                        torch.allclose)
+
+    def test_transfer_tensor_multidimensional(self):
+        """Test transferring a multi-dimensional tensor."""
+        self._roundtrip(torch.randn(3, 4, dtype=torch.float32),
+                        torch.allclose)
+
+
+if __name__ == '__main__':
+    unittest.main()