-
Notifications
You must be signed in to change notification settings - Fork 387
[TransferEngine]feat: add tensor transfer Read/Write API for transfer-engine #703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -20,6 +20,8 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include <pybind11/stl.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto torch = py::module_::import("torch"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#ifdef USE_MNNVL | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#include "transport/nvlink_transport/nvlink_transport.h" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
static void *allocateMemory(size_t size) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -53,6 +55,72 @@ TransferEnginePy::~TransferEnginePy() { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
large_buffer_list_.clear(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
template <typename T> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::array TransferEnginePy::create_typed_array(char *exported_data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::capsule free_when_done( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
exported_data, [](void *p) { delete[] static_cast<char *>(p); }); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return py::array_t<T>({static_cast<ssize_t>(total_length / sizeof(T))}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
(T *)(exported_data + offset), free_when_done); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
using ArrayCreatorFunc = std::function<py::array(char *, size_t, size_t)>; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
static const std::array<ArrayCreatorFunc, 11> array_creators = {{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<float>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // FLOAT32 = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<double>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // FLOAT64 = 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<int8_t>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // INT8 = 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<uint8_t>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // UINT8 = 3 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<int16_t>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // INT16 = 4 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<uint16_t>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // UINT16 = 5 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<int32_t>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // INT32 = 6 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<uint32_t>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // UINT32 = 7 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<int64_t>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // INT64 = 8 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<uint64_t>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}, // UINT64 = 9 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](char* data, size_t offset, size_t total_length) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TransferEnginePy{}.create_typed_array<bool>(data, offset, total_length); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Each lambda creates a temporary TransferEnginePy object to call create_typed_array. This is inefficient and unnecessary since create_typed_array could be made static or the lambdas could directly implement the array creation logic.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} // BOOL = 10 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
}}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TensorDtype TransferEnginePy::get_tensor_dtype(py::object dtype_obj) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.is_none()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TensorDtype::UNKNOWN; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("float32"))) return TensorDtype::FLOAT32; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("float64"))) return TensorDtype::FLOAT64; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("int8"))) return TensorDtype::INT8; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("uint8"))) return TensorDtype::UINT8; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("int16"))) return TensorDtype::INT16; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("uint16"))) return TensorDtype::UINT16; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("int32"))) return TensorDtype::INT32; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("uint32"))) return TensorDtype::UINT32; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("int64"))) return TensorDtype::INT64; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("uint64"))) return TensorDtype::UINT64; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_obj.equal(torch.attr("bool"))) return TensorDtype::BOOL; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return TensorDtype::UNKNOWN; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::vector<std::string> buildDeviceFilter(const std::string &device_names) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::stringstream ss(device_names); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::string item; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -474,6 +542,184 @@ batch_id_t TransferEnginePy::batchTransferAsync(const char *target_hostname, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return batch_id; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int TransferEnginePy::transferTensorSyncWrite(const char* target_hostname, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pybind11::object tensor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
uintptr_t peer_buffer_address) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
try { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Check whether it is pytorch tensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!(tensor.attr("__class__") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.attr("__name__") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.cast<std::string>() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.find("Tensor") != std::string::npos)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Input is not a PyTorch tensor"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Get tensor metadata | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
uintptr_t data_ptr = tensor.attr("data_ptr")().cast<uintptr_t>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t numel = tensor.attr("numel")().cast<size_t>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t element_size = tensor.attr("element_size")().cast<size_t>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t tensor_size = numel * element_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pybind11::object shape_obj = tensor.attr("shape"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pybind11::object dtype_obj = tensor.attr("dtype"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Build tensor metadata | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TensorDtype dtype_enum = get_tensor_dtype(dtype_obj); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_enum == TensorDtype::UNKNOWN) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Unsupported tensor dtype!"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Currently support tensors with no more than 4 dimensions | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pybind11::tuple shape_tuple = pybind11::cast<pybind11::tuple>(shape_obj); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int32_t ndim = static_cast<int32_t>(shape_tuple.size()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (ndim > 4) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Tensor has more than 4 dimensions: " << ndim; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TensorMetadata metadata; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
metadata.dtype = static_cast<int32_t>(dtype_enum); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
metadata.ndim = ndim; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int i = 0; i < 4; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (i < ndim) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
metadata.shape[i] = shape_tuple[i].cast<int32_t>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
metadata.shape[i] = -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nitpick] Using -1 as a sentinel value for unused shape dimensions is unclear. Consider using 0 or defining a named constant to make the intent more explicit.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Calculate total size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t total_size = sizeof(TensorMetadata) + tensor_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Allocate single buffer for metadata + tensor data | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
uintptr_t local_buffer = allocateManagedBuffer(total_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (local_buffer == 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Failed to allocate combined buffer"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Copy metadata to buffer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
memcpy(reinterpret_cast<void*>(local_buffer), &metadata, sizeof(TensorMetadata)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Copy tensor data to buffer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
memcpy(reinterpret_cast<void*>(local_buffer + sizeof(TensorMetadata)), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
reinterpret_cast<void*>(data_ptr), tensor_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Single transfer for the entire data | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int ret = transferSync(target_hostname, local_buffer, peer_buffer_address, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
total_size, TransferOpcode::WRITE); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Clean up | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
freeManagedBuffer(local_buffer, total_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return ret; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} catch (const pybind11::error_already_set &e) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Failed to access tensor data: " << e.what(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return -1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pybind11::object TransferEnginePy::transferTensorSyncRead(const char* target_hostname, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
uintptr_t peer_buffer_address, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t total_size) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
try { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Allocate single buffer for the entire data | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
uintptr_t local_buffer = allocateManagedBuffer(total_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (local_buffer == 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Failed to allocate receive buffer, size: " << total_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::gil_scoped_acquire acquire_gil; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return pybind11::none(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Single transfer to read the entire data | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int ret = transferSync(target_hostname, local_buffer, peer_buffer_address, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
total_size, TransferOpcode::READ); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (ret != 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
freeManagedBuffer(local_buffer, total_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Failed to transfer data, ret: " << ret; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::gil_scoped_acquire acquire_gil; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return pybind11::none(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Parse the metadata from the beginning of buffer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TensorMetadata metadata; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
memcpy(&metadata, reinterpret_cast<void*>(local_buffer), sizeof(TensorMetadata)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Add debug logging | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(INFO) << "Read metadata: dtype=" << metadata.dtype << ", ndim=" << metadata.ndim; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(INFO) << "Shape: [" << metadata.shape[0] << "," << metadata.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
<< "," << metadata.shape[2] << "," << metadata.shape[3] << "]"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (metadata.ndim < 0 || metadata.ndim > 4) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
freeManagedBuffer(local_buffer, total_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Invalid tensor metadata: ndim=" << metadata.ndim; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::gil_scoped_acquire acquire_gil; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return pybind11::none(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TensorDtype dtype_enum = static_cast<TensorDtype>(metadata.dtype); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (metadata.dtype < 0 || metadata.dtype > 10 || dtype_enum == TensorDtype::UNKNOWN) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
freeManagedBuffer(local_buffer, total_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Unknown tensor dtype: " << metadata.dtype; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::gil_scoped_acquire acquire_gil; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return pybind11::none(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Calculate actual tensor size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
size_t tensor_size = total_size - sizeof(TensorMetadata); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(INFO) << "Total size: " << total_size << ", metadata size: " << sizeof(TensorMetadata) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
<< ", tensor size: " << tensor_size; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Create contiguous memory copy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
char* exported_data = new char[total_size]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
memcpy(exported_data, reinterpret_cast<void*>(local_buffer), total_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Release managed buffer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
freeManagedBuffer(local_buffer, total_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Create numpy array | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pybind11::object np_array; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int dtype_index = static_cast<int>(dtype_enum); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (dtype_index >= 0 && dtype_index < static_cast<int>(array_creators.size())) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
np_array = array_creators[dtype_index](exported_data, sizeof(TensorMetadata), tensor_size); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
delete[] exported_data; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Unsupported dtype enum: " << dtype_index; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::gil_scoped_acquire acquire_gil; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return pybind11::none(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
// Reshape tensor data (only for ndim > 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (metadata.ndim > 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::vector<int> shape_vec; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (int i = 0; i < metadata.ndim; i++) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (metadata.shape[i] > 0) { // Only add valid dimensions | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
shape_vec.push_back(metadata.shape[i]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The condition
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!shape_vec.empty()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::tuple shape_tuple = py::cast(shape_vec); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(INFO) << "Reshaping to " << shape_vec.size() << " dimensions"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
np_array = np_array.attr("reshape")(shape_tuple); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::gil_scoped_acquire acquire_gil; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pybind11::object tensor = torch.attr("from_numpy")(np_array); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return tensor; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} catch (const pybind11::error_already_set &e) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
LOG(ERROR) << "Failed to get tensor data: " << e.what(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::gil_scoped_acquire acquire_gil; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return pybind11::none(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
int TransferEnginePy::getBatchTransferStatus(const std::vector<batch_id_t>& batch_ids) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pybind11::gil_scoped_release release; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TransferStatus status; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -624,6 +870,20 @@ uintptr_t TransferEnginePy::getFirstBufferAddress( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
namespace py = pybind11; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
PYBIND11_MODULE(engine, m) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::enum_<TensorDtype>(m, "TensorDtype", py::arithmetic()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("FLOAT32", TensorDtype::FLOAT32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("FLOAT64", TensorDtype::FLOAT64) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("INT8", TensorDtype::INT8) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("UINT8", TensorDtype::UINT8) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("INT16", TensorDtype::INT16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("UINT16", TensorDtype::UINT16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("INT32", TensorDtype::INT32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("UINT32", TensorDtype::UINT32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("INT64", TensorDtype::INT64) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("UINT64", TensorDtype::UINT64) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("BOOL", TensorDtype::BOOL) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.value("UNKNOWN", TensorDtype::UNKNOWN) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.export_values(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
py::enum_<TransferEnginePy::TransferOpcode> transfer_opcode( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
m, "TransferOpcode", py::arithmetic()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
transfer_opcode.value("Read", TransferEnginePy::TransferOpcode::READ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -648,6 +908,8 @@ PYBIND11_MODULE(engine, m) { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.def("transfer_sync", &TransferEnginePy::transferSync) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.def("batch_transfer_sync", &TransferEnginePy::batchTransferSync) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.def("batch_transfer_async", &TransferEnginePy::batchTransferAsync) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.def("transfer_tensor_sync_write", &TransferEnginePy::transferTensorSyncWrite) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.def("transfer_tensor_sync_read", &TransferEnginePy::transferTensorSyncRead) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.def("get_batch_transfer_status", &TransferEnginePy::getBatchTransferStatus) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
.def("transfer_submit_write", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
&TransferEnginePy::transferSubmitWrite) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -15,6 +15,7 @@ | |||||||
#include <gflags/gflags.h> | ||||||||
#include <glog/logging.h> | ||||||||
#include <pybind11/pybind11.h> | ||||||||
#include <pybind11/numpy.h> | ||||||||
#include <sys/time.h> | ||||||||
|
||||||||
#include <cstdlib> | ||||||||
|
@@ -31,6 +32,7 @@ | |||||||
#include "transport/transport.h" | ||||||||
|
||||||||
using namespace mooncake; | ||||||||
namespace py = pybind11; | ||||||||
|
||||||||
const static size_t kDefaultBufferCapacity = 2ull * 1024 * 1024 * 1024; | ||||||||
const static size_t kSlabSizeKBTabLen = 16; | ||||||||
|
@@ -40,6 +42,17 @@ const static size_t kSlabSizeKB[] = { | |||||||
512, 1024, 2 * 1024, 4 * 1024, 8 * 1024, 16 * 1024, | ||||||||
32 * 1024, 64 * 1024, 128 * 1024, 256 * 1024}; | ||||||||
|
||||||||
enum class TensorDtype : int32_t { | ||||||||
FLOAT32 = 0, FLOAT64 = 1, INT8 = 2, UINT8 = 3, INT16 = 4, UINT16 = 5, | ||||||||
INT32 = 6, UINT32 = 7, INT64 = 8, UINT64 = 9, BOOL = 10, UNKNOWN = -1 | ||||||||
}; | ||||||||
|
||||||||
struct TensorMetadata { | ||||||||
int32_t dtype; | ||||||||
int32_t ndim; | ||||||||
int32_t shape[4]; | ||||||||
} __attribute__((packed)); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The fixed-size shape array limits tensors to 4 dimensions. Consider using a more flexible approach or document this limitation clearly, as PyTorch tensors can have more than 4 dimensions in practice.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback |
||||||||
|
||||||||
class TransferEnginePy { | ||||||||
public: | ||||||||
enum class TransferOpcode { READ = 0, WRITE = 1 }; | ||||||||
|
@@ -114,6 +127,14 @@ class TransferEnginePy { | |||||||
const std::vector<uintptr_t> &peer_buffer_addresses, | ||||||||
const std::vector<size_t> &lengths, | ||||||||
TransferOpcode opcode); | ||||||||
|
||||||||
int transferTensorSyncWrite(const char* target_hostname, | ||||||||
pybind11::object tensor, | ||||||||
uintptr_t peer_buffer_address); | ||||||||
|
||||||||
pybind11::object transferTensorSyncRead(const char* target_hostname, | ||||||||
uintptr_t peer_buffer_address, | ||||||||
size_t length); | ||||||||
|
||||||||
int getBatchTransferStatus(const std::vector<batch_id_t> &batch_ids); | ||||||||
|
||||||||
|
@@ -142,13 +163,18 @@ class TransferEnginePy { | |||||||
|
||||||||
int batchUnregisterMemory(std::vector<uintptr_t> buffer_addresses); | ||||||||
|
||||||||
template<typename T> | ||||||||
py::array create_typed_array(char* data, size_t offset, size_t total_length); | ||||||||
|
||||||||
private: | ||||||||
char *allocateRawBuffer(size_t capacity); | ||||||||
|
||||||||
int findClassId(size_t size); | ||||||||
|
||||||||
int doBuddyAllocate(int class_id); | ||||||||
|
||||||||
TensorDtype get_tensor_dtype(py::object dtype_obj); | ||||||||
|
||||||||
private: | ||||||||
std::shared_ptr<TransferEngine> engine_; | ||||||||
Transport *xport_; | ||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The global torch module import at file scope could cause issues if torch is not available when the module is loaded. Consider importing torch lazily within functions that need it, with proper error handling for cases where torch is not installed.
Copilot uses AI. Check for mistakes.