@@ -109,7 +109,8 @@ class MooncakeStorePyWrapper {
109109 }
110110 }
111111
112- pybind11::object get_tensor (const std::string &key) {
112+ pybind11::object get_tensor (const std::string &key, int tp_rank = 0 ,
113+ int tp_size = 1 , int split_dim = 0 ) {
113114 if (!is_client_initialized ()) {
114115 LOG (ERROR) << " Client is not initialized" ;
115116 return pybind11::none ();
@@ -130,15 +131,27 @@ class MooncakeStorePyWrapper {
130131 }
131132 // Create contiguous buffer and copy data
132133 auto total_length = buffer_handle->size ();
134+ // Validate data size
135+ if (total_length <= sizeof (TensorMetadata)) {
136+ py::gil_scoped_acquire acquire_gil;
137+ LOG (ERROR)
138+ << " Invalid data format: insufficient data for metadata" ;
139+ return pybind11::none ();
140+ }
141+
133142 char *exported_data = new char [total_length];
134143 if (!exported_data) {
135144 py::gil_scoped_acquire acquire_gil;
136- LOG (ERROR) << " Invalid data format: insufficient data for "
137- " metadata" ;
145+ LOG (ERROR) << " Failed to allocate memory for tensor data" ;
138146 return pybind11::none ();
139147 }
140148 TensorMetadata metadata;
141149 // Copy data from buffer to contiguous memory
150+ // Note: currently we copy the WHOLE buffer to host memory.
151+ // TODO(perf): for split_dim == 0 we could compute per-rank offsets
152+ // and copy only the slice this rank needs, saving host memory; that
153+ // requires metadata manipulation, so for now we use the simpler and
154+ // more robust torch.chunk approach below.
142155 memcpy (exported_data, buffer_handle->ptr (), total_length);
143156 memcpy (&metadata, exported_data, sizeof (TensorMetadata));
144157
@@ -174,6 +187,7 @@ class MooncakeStorePyWrapper {
174187 np_array = array_creators[dtype_index](
175188 exported_data, sizeof (TensorMetadata), tensor_size);
176189 } else {
190+ delete[] exported_data;
177191 LOG (ERROR) << " Unsupported dtype enum: " << dtype_index;
178192 return pybind11::none ();
179193 }
@@ -186,8 +200,35 @@ class MooncakeStorePyWrapper {
186200 py::tuple shape_tuple = py::cast (shape_vec);
187201 np_array = np_array.attr (" reshape" )(shape_tuple);
188202 }
203+ // Get the full tensor first
189204 pybind11::object tensor =
190205 torch_module ().attr (" from_numpy" )(np_array);
206+
207+ if (tp_size > 1 ) {
208+ if (split_dim < 0 || split_dim >= metadata.ndim ) {
209+ LOG (ERROR) << " Invalid split_dim " << split_dim
210+ << " for ndim " << metadata.ndim ;
211+ return pybind11::none (); // NOTE(review): policy undecided — fail
212+ // fast vs. return full tensor; also confirm exported_data is freed
213+ }
214+
215+ // Use torch.chunk to split the tensor
216+ py::object chunks = tensor.attr (" chunk" )(tp_size, split_dim);
217+ py::tuple chunks_tuple = chunks.cast <py::tuple>();
218+
219+ if (tp_rank < 0 ||
220+ tp_rank >= static_cast <int >(chunks_tuple.size ())) {
221+ LOG (ERROR) << " Invalid tp_rank " << tp_rank
222+ << " for tp_size " << tp_size;
223+ return pybind11::none ();
224+ }
225+
226+ // Return only the slice for this rank
227+ // We call contiguous() to ensure the memory layout is clean for
228+ // subsequent GPU transfer
229+ return chunks_tuple[tp_rank].attr (" contiguous" )();
230+ }
231+
191232 return tensor;
192233
193234 } catch (const pybind11::error_already_set &e) {
@@ -726,8 +767,17 @@ PYBIND11_MODULE(store, m) {
726767 py::gil_scoped_release release;
727768 return self.store_ ->getSize (key);
728769 })
729- .def (" get_tensor" , &MooncakeStorePyWrapper::get_tensor, py::arg (" key" ),
730- " Get a PyTorch tensor from the store" )
770+ .def (
771+ " get_tensor" , &MooncakeStorePyWrapper::get_tensor, py::arg (" key" ),
772+ py::arg (" tp_rank" ) = 0 , py::arg (" tp_size" ) = 1 ,
773+ py::arg (" split_dim" ) = 0 ,
774+ " Get a PyTorch tensor from the store, optionally sliced for Tensor "
775+ " Parallelism.\n "
776+ " Args:\n "
777+ " key: The key of the tensor.\n "
778+ " tp_rank: The current tensor parallel rank (default 0).\n "
779+ " tp_size: The total tensor parallel size (default 1).\n "
780+ " split_dim: The dimension to split the tensor along (default 0)." )
731781 .def (" put_tensor" , &MooncakeStorePyWrapper::put_tensor, py::arg (" key" ),
732782 py::arg (" tensor" ), " Put a PyTorch tensor into the store" )
733783 .def (" batch_get_tensor" , &MooncakeStorePyWrapper::batch_get_tensor,
0 commit comments