Skip to content

Commit d5a36f2

Browse files
committed
Add tensor-parallelism (TP) awareness to get_tensor: optional tp_rank/tp_size/split_dim arguments return only the slice belonging to the caller's rank
Signed-off-by: Xuchun Shang <[email protected]>
1 parent 7c02484 commit d5a36f2

File tree

2 files changed

+191
-85
lines changed

2 files changed

+191
-85
lines changed

mooncake-integration/store/store_py.cpp

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ class MooncakeStorePyWrapper {
109109
}
110110
}
111111

112-
pybind11::object get_tensor(const std::string &key) {
112+
pybind11::object get_tensor(const std::string &key, int tp_rank = 0,
113+
int tp_size = 1, int split_dim = 0) {
113114
if (!is_client_initialized()) {
114115
LOG(ERROR) << "Client is not initialized";
115116
return pybind11::none();
@@ -130,15 +131,27 @@ class MooncakeStorePyWrapper {
130131
}
131132
// Create contiguous buffer and copy data
132133
auto total_length = buffer_handle->size();
134+
// Validate data size
135+
if (total_length <= sizeof(TensorMetadata)) {
136+
py::gil_scoped_acquire acquire_gil;
137+
LOG(ERROR)
138+
<< "Invalid data format: insufficient data for metadata";
139+
return pybind11::none();
140+
}
141+
133142
char *exported_data = new char[total_length];
134143
if (!exported_data) {
135144
py::gil_scoped_acquire acquire_gil;
136-
LOG(ERROR) << "Invalid data format: insufficient data for "
137-
"metadata";
145+
LOG(ERROR) << "Failed to allocate memory for tensor data";
138146
return pybind11::none();
139147
}
140148
TensorMetadata metadata;
141149
// Copy data from buffer to contiguous memory
150+
// Note: Currently we copy the WHOLE buffer.
151+
// Optimization Opportunity: For split_dim=0, we could calculate
152+
// offsets and only copy the relevant slice to save Host Memory, but
153+
// that requires complex metadata manipulation. Here we use the
154+
// robust torch.chunk approach.
142155
memcpy(exported_data, buffer_handle->ptr(), total_length);
143156
memcpy(&metadata, exported_data, sizeof(TensorMetadata));
144157

@@ -174,6 +187,7 @@ class MooncakeStorePyWrapper {
174187
np_array = array_creators[dtype_index](
175188
exported_data, sizeof(TensorMetadata), tensor_size);
176189
} else {
190+
delete[] exported_data;
177191
LOG(ERROR) << "Unsupported dtype enum: " << dtype_index;
178192
return pybind11::none();
179193
}
@@ -186,8 +200,35 @@ class MooncakeStorePyWrapper {
186200
py::tuple shape_tuple = py::cast(shape_vec);
187201
np_array = np_array.attr("reshape")(shape_tuple);
188202
}
203+
// Get the full tensor first
189204
pybind11::object tensor =
190205
torch_module().attr("from_numpy")(np_array);
206+
207+
if (tp_size > 1) {
208+
if (split_dim < 0 || split_dim >= metadata.ndim) {
209+
LOG(ERROR) << "Invalid split_dim " << split_dim
210+
<< " for ndim " << metadata.ndim;
211+
return pybind11::none(); // Or return full tensor depending
212+
// on policy
213+
}
214+
215+
// Use torch.chunk to split the tensor
216+
py::object chunks = tensor.attr("chunk")(tp_size, split_dim);
217+
py::tuple chunks_tuple = chunks.cast<py::tuple>();
218+
219+
if (tp_rank < 0 ||
220+
tp_rank >= static_cast<int>(chunks_tuple.size())) {
221+
LOG(ERROR) << "Invalid tp_rank " << tp_rank
222+
<< " for tp_size " << tp_size;
223+
return pybind11::none();
224+
}
225+
226+
// Return only the slice for this rank
227+
// We call contiguous() to ensure the memory layout is clean for
228+
// subsequent GPU transfer
229+
return chunks_tuple[tp_rank].attr("contiguous")();
230+
}
231+
191232
return tensor;
192233

193234
} catch (const pybind11::error_already_set &e) {
@@ -726,8 +767,17 @@ PYBIND11_MODULE(store, m) {
726767
py::gil_scoped_release release;
727768
return self.store_->getSize(key);
728769
})
729-
.def("get_tensor", &MooncakeStorePyWrapper::get_tensor, py::arg("key"),
730-
"Get a PyTorch tensor from the store")
770+
.def(
771+
"get_tensor", &MooncakeStorePyWrapper::get_tensor, py::arg("key"),
772+
py::arg("tp_rank") = 0, py::arg("tp_size") = 1,
773+
py::arg("split_dim") = 0,
774+
"Get a PyTorch tensor from the store, optionally sliced for Tensor "
775+
"Parallelism.\n"
776+
"Args:\n"
777+
" key: The key of the tensor.\n"
778+
" tp_rank: The current tensor parallel rank (default 0).\n"
779+
" tp_size: The total tensor parallel size (default 1).\n"
780+
" split_dim: The dimension to split the tensor along (default 0).")
731781
.def("put_tensor", &MooncakeStorePyWrapper::put_tensor, py::arg("key"),
732782
py::arg("tensor"), "Put a PyTorch tensor into the store")
733783
.def("batch_get_tensor", &MooncakeStorePyWrapper::batch_get_tensor,

0 commit comments

Comments
 (0)