[Store] add tp awareness for get_tensor #1127
base: main
Conversation
Summary of Changes: Hello @XucSh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request enhances the Mooncake store's get_tensor API with tensor-parallelism (TP) awareness.
Signed-off-by: Xuchun Shang <[email protected]>
Code Review
This pull request introduces tensor parallelism (TP) awareness to the get_tensor function, allowing tensors to be sliced for distributed processing. The implementation correctly leverages torch.chunk for this purpose and includes new parameters tp_rank, tp_size, and split_dim. The accompanying Python test script has been significantly improved to include a comprehensive benchmark and validation for this new TP functionality, covering both row and column parallelism. The changes are well-implemented, but I've identified a potential high-severity issue related to unhandled exceptions during memory allocation that could lead to a crash, and a medium-severity issue with a potentially misleading log message. My review includes suggestions to address these points.
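To make the new surface concrete, a minimal usage sketch might look like the following. The import path, store setup, and keyword form of the new parameters are assumptions based on this PR's description and the existing Python bindings, not confirmed API:

```python
import torch
from mooncake.store import MooncakeDistributedStore  # import path assumed

store = MooncakeDistributedStore()
# store.setup(...)  # connection arguments omitted in this sketch

weight = torch.randn(4096, 4096, dtype=torch.float32)
store.batch_put_tensor(["layer0.weight"], [weight])

# Each TP worker fetches only its own slice of the stored tensor.
# split_dim=0 is row parallelism; split_dim=1 would be column parallelism.
tp_size = 4
for tp_rank in range(tp_size):
    shard = store.get_tensor("layer0.weight",
                             tp_rank=tp_rank, tp_size=tp_size, split_dim=0)
    assert shard.shape == (4096 // tp_size, 4096)
    assert shard.is_contiguous()
```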
| return pybind11::none(); | ||
| } | ||
| char *exported_data = new char[total_length]; |
The new operator may throw a std::bad_alloc exception on allocation failure. This exception is not caught by the surrounding try...catch block, which only handles pybind11::error_already_set, and would lead to an unhandled exception that terminates the process. The if (!exported_data) check on the next line would not be reached.
To make this check effective and prevent a crash, you should use the nothrow version of new.
| char *exported_data = new char[total_length]; | |
| char *exported_data = new (std::nothrow) char[total_length]; |
| LOG(ERROR) << "Invalid tp_rank " << tp_rank | ||
| << " for tp_size " << tp_size; |
The error message is slightly misleading. It reports tp_size as the boundary for tp_rank, but the check is against chunks_tuple.size(). Since torch.chunk can produce fewer chunks than tp_size (e.g., if the dimension being split is smaller than tp_size), chunks_tuple.size() can be different from tp_size. The log message should reflect the actual boundary being checked for better diagnostics.
| LOG(ERROR) << "Invalid tp_rank " << tp_rank | |
| << " for tp_size " << tp_size; | |
| LOG(ERROR) << "Invalid tp_rank " << tp_rank | |
| << " for chunk count " << chunks_tuple.size(); |
Force-pushed from d5a36f2 to 9e154b5.
Pull request overview
This PR adds Tensor Parallelism (TP) awareness to the get_tensor method in the Mooncake Store, allowing retrieval of specific tensor slices based on TP rank and size. This enables efficient distributed tensor loading for model parallelism workloads without requiring clients to manually slice tensors.
Key Changes
- Extended `get_tensor` API with optional TP parameters (`tp_rank`, `tp_size`, `split_dim`)
- Implemented tensor slicing using PyTorch's `chunk` operation with a contiguous memory guarantee
- Added comprehensive TP awareness test cases, including row and column parallelism validation (see the sketch below)
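The row/column-parallel validation those tests perform boils down to the following pure-PyTorch sketch; nothing here depends on the store bindings themselves:

```python
import torch

full = torch.randn(8, 8)
tp_size = 4

# Row parallelism: split along dim 0; column parallelism: split along dim 1.
row_shards = torch.chunk(full, tp_size, dim=0)
col_shards = torch.chunk(full, tp_size, dim=1)

# Reconstruction check, mirroring the "Reconstruction Data Mismatch" test:
assert torch.equal(torch.cat(row_shards, dim=0), full)
assert torch.equal(torch.cat(col_shards, dim=1), full)

# A column chunk is a strided view and not contiguous; the store calls
# .contiguous() before exporting, which is what the is_contiguous() test checks.
assert not col_shards[0].is_contiguous()
assert col_shards[0].contiguous().is_contiguous()
```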
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| mooncake-integration/store/store_py.cpp | Implements TP-aware tensor retrieval with validation, splitting logic, and updated Python bindings |
| scripts/test_tensor_api.py | Refactored benchmark script with new TP awareness tests, improved output formatting, and updated configuration constants |
scripts/test_tensor_api.py
Outdated
| print(f" ❌ Reconstruction Data Mismatch!") | ||
| print("\n✅ Benchmark finished.") | ||
| print("\n✅ All TP Tests Passed.") |
Copilot AI · Nov 27, 2025
The test prints "✅ All TP Tests Passed" unconditionally at the end of Test 3, even if reconstruction failed or validation checks failed in subtests A or B. This could lead to false positives where the test reports success despite actual failures. Consider tracking test failures with a flag and conditionally printing success/failure messages.
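A minimal sketch of the flag-based approach suggested here; the helper name and structure are illustrative, not taken from the PR:

```python
import sys

all_passed = True

def check(condition: bool, message: str) -> None:
    """Record a failure instead of only printing and moving on."""
    global all_passed
    if not condition:
        print(f" ❌ {message}")
        all_passed = False

# Inside the TP subtests, e.g.:
#   check(t_slice.shape[0] == expected_rows, f"Rank {rank} shape mismatch: {t_slice.shape}")
#   check(t_slice.is_contiguous(), f"Rank {rank} tensor is NOT contiguous")

if all_passed:
    print("\n✅ All TP Tests Passed.")
else:
    print("\n❌ Some TP tests FAILED.")
    sys.exit(1)
```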
scripts/test_tensor_api.py
Outdated
| avg_get_throughput = (TOTAL_BATCH_SIZE_BYTES * 8) / (avg_get_time * (1024**3)) | ||
| print(f"Average GET Time: {avg_get_time:.4f} s") | ||
| print(f"Average GET Throughput: {avg_get_throughput:.2f} Gbps") | ||
| print(f" Tensor Shape: {t_slice.shape}") |
Copilot AI · Nov 27, 2025
The print statement outputs "Tensor Shape:" without any rank identifier, making it unclear which rank's shape is being printed. This creates ambiguous output when multiple ranks are printed in sequence. Consider including the rank number in the output, e.g., print(f" Rank {rank} Shape: {t_slice.shape}").
| char *exported_data = new char[total_length]; | ||
| if (!exported_data) { | ||
| py::gil_scoped_acquire acquire_gil; | ||
| LOG(ERROR) << "Invalid data format: insufficient data for " | ||
| "metadata"; | ||
| LOG(ERROR) << "Failed to allocate memory for tensor data"; | ||
| return pybind11::none(); | ||
| } |
Copilot AI · Nov 27, 2025
The validation at line 135 checks total_length <= sizeof(TensorMetadata), but the allocation at line 142 occurs without checking if the allocation succeeded (if (!exported_data) check on line 143). In modern C++, new throws std::bad_alloc on failure by default rather than returning nullptr. The null check on line 143 is unreachable unless using new (std::nothrow). Either remove the null check or use new (std::nothrow) for the allocation.
scripts/test_tensor_api.py
Outdated
| dim = (dim // 8) * 8 | ||
| tensors_list = [ | ||
| torch.randn(elements_per_tensor, dtype=torch.float32) | ||
| torch.randn(dim, dim, dtype=torch.float32) | ||
| for _ in range(NUM_TENSORS) | ||
| ] | ||
| keys_list = [f"perf_tensor_{i}" for i in range(NUM_TENSORS)] | ||
| print(f"Data prepared: {NUM_TENSORS} tensors, {TENSOR_SIZE_MB} MB each.") | ||
| keys_list = [f"bench_tensor_{i}" for i in range(NUM_TENSORS)] | ||
| print(f" Created {NUM_TENSORS} tensors of shape [{dim}, {dim}] (approx {TENSOR_SIZE_MB} MB each)") |
Copilot AI · Nov 27, 2025
[nitpick] The actual tensor size will be approximately (dim * dim * 4) / (1024 * 1024) MB, which may differ from TENSOR_SIZE_MB (64 MB) due to the dimension adjustment at line 141. For example, with TENSOR_SIZE_MB = 64, dim would be calculated as 4096, then adjusted to 4096 (still divisible by 8), resulting in tensors of exactly 64 MB. However, the (dim // 8) * 8 adjustment could cause the actual size to be slightly smaller than intended if dim is not initially divisible by 8. The output message says "(approx {TENSOR_SIZE_MB} MB each)" which is correct, but consider also printing the actual calculated size for clarity.
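For illustration, the size bookkeeping described above could be made explicit as follows; the constant names mirror the test script, and the printed wording is only a suggestion:

```python
TENSOR_SIZE_MB = 64

# float32 is 4 bytes per element; pick dim so that dim*dim*4 ~ TENSOR_SIZE_MB MiB,
# then round down to a multiple of 8 so the tensor splits cleanly across TP ranks.
elements = TENSOR_SIZE_MB * 1024 * 1024 // 4
dim = int(elements ** 0.5)
dim = (dim // 8) * 8

actual_mb = dim * dim * 4 / (1024 * 1024)
print(f" Created tensors of shape [{dim}, {dim}] "
      f"(target {TENSOR_SIZE_MB} MB, actual {actual_mb:.2f} MB each)")
```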
scripts/test_tensor_api.py
Outdated
| if t_slice.shape[0] != expected_rows: | ||
| print(f" ❌ Rank {rank} shape mismatch! Got {t_slice.shape}, expected [{expected_rows}, ...]") | ||
| if not t_slice.is_contiguous(): | ||
| print(f" ❌ Rank {rank} tensor is NOT contiguous!") |
Copilot AI · Nov 27, 2025
The test validation logic prints error messages but does not terminate or raise exceptions when validations fail (e.g., shape mismatches or non-contiguous tensors). The test continues execution and may print "✅ All TP Tests Passed" even when intermediate validation checks have failed. Consider raising exceptions or calling sys.exit(1) after printing validation failure messages to ensure the test accurately reports failures.
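The fail-fast variant could instead raise as soon as a check fails, for example with an illustrative helper like this (not part of the current script):

```python
class TPTestFailure(AssertionError):
    """Raised as soon as a TP validation check fails."""

def require(condition: bool, message: str) -> None:
    if not condition:
        raise TPTestFailure(message)

# Example use inside the per-rank loop:
#   require(t_slice.shape[0] == expected_rows,
#           f"Rank {rank} shape mismatch: got {tuple(t_slice.shape)}")
#   require(t_slice.is_contiguous(), f"Rank {rank} tensor is NOT contiguous")
```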
scripts/test_tensor_api.py
Outdated
| if t_slice is None: | ||
| print(f" ❌ Rank {rank} failed.") | ||
| sys.exit(1) | ||
| print(f" Tensor Shape: {t_slice.shape}") |
Copilot AI · Nov 27, 2025
The print statement outputs "Tensor Shape:" without any rank identifier, making it unclear which rank's shape is being printed. Consider including the rank number in the output, e.g., print(f" Rank {rank} Shape: {t_slice.shape}").
| // on policy | ||
| } | ||
| // Use torch.chunk to split the tensor |
Copilot AI · Nov 27, 2025
[nitpick] The code uses torch.chunk(), which may return fewer chunks than tp_size if the dimension size is not evenly divisible. For example, if a tensor has shape [10, 10] and tp_size=4 with split_dim=0, torch.chunk returns 4 chunks of sizes [3, 3, 3, 1], but with tp_size=7 it returns only 5 chunks of size 2. The validation at line 219-224 checks tp_rank >= chunks_tuple.size(), which handles this case, but the API documentation should clarify this behavior to users.
| // Use torch.chunk to split the tensor | |
| // Use torch.chunk to split the tensor. | |
| // NOTE: torch.chunk() may return fewer chunks than tp_size if the dimension size | |
| // is not evenly divisible by tp_size. For example, if the tensor's size along | |
| // split_dim is 10 and tp_size is 4, torch.chunk returns 4 chunks of sizes [3, 3, 3, 1], | |
| // and if tp_size is 7 it returns only 5 chunks of size 2. | |
| // Users should not assume that the number of chunks always equals tp_size. |
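A quick runnable illustration of this chunk-count behaviour:

```python
import torch

t = torch.arange(10 * 10, dtype=torch.float32).reshape(10, 10)

# 10 rows into 4 chunks: chunk size ceil(10/4) = 3, so sizes are [3, 3, 3, 1].
print([c.shape[0] for c in torch.chunk(t, 4, dim=0)])  # [3, 3, 3, 1]

# 10 rows into 7 requested chunks: chunk size ceil(10/7) = 2, so only 5 chunks.
print([c.shape[0] for c in torch.chunk(t, 7, dim=0)])  # [2, 2, 2, 2, 2]

# Hence tp_rank must be validated against the actual chunk count, not tp_size.
```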
| // Get the full tensor first | ||
| pybind11::object tensor = | ||
| torch_module().attr("from_numpy")(np_array); | ||
Copilot AI · Nov 27, 2025
Missing validation: There's no check to ensure that tp_size > 0. If tp_size is 0 or negative, the code at line 207 will proceed to the chunk operation which could lead to undefined behavior. Add validation to check tp_size > 0 before line 207, and return an error if invalid.
| if (tp_size <= 0) { | |
| LOG(ERROR) << "Invalid tp_size " << tp_size << ": must be > 0"; | |
| return pybind11::none(); | |
| } |
| // Optimization Opportunity: For split_dim=0, we could calculate | ||
| // offsets and only copy the relevant slice to save Host Memory, but | ||
| // that requires complex metadata manipulation. Here we use the | ||
| // robust torch.chunk approach. |
Copilot AI · Nov 27, 2025
[nitpick] The comment on lines 150-154 mentions an "Optimization Opportunity" for split_dim=0 to copy only the relevant slice. However, this optimization is complex and the current approach using torch.chunk is described as "robust". Consider clarifying whether this optimization is planned for future implementation or if it's just informational. The comment could also benefit from explaining why the torch.chunk approach is preferred despite copying the entire buffer.
| // Optimization Opportunity: For split_dim=0, we could calculate | |
| // offsets and only copy the relevant slice to save Host Memory, but | |
| // that requires complex metadata manipulation. Here we use the | |
| // robust torch.chunk approach. | |
| // Optimization Opportunity (informational): For split_dim=0, we could | |
| // calculate offsets and only copy the relevant slice to save Host Memory. | |
| // However, this would require complex metadata manipulation and is not | |
| // planned for immediate implementation. We use the torch.chunk approach | |
| // because it is robust and simple, ensuring correctness even though it | |
| // may copy more data than strictly necessary. |
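Purely as an illustration of that offset-based idea (not implemented in this PR), slicing one rank's rows out of a flat float32 buffer would look roughly like this, assuming an evenly divisible row count:

```python
import numpy as np

def row_slice_view(buffer: bytes, rows: int, cols: int,
                   tp_rank: int, tp_size: int) -> np.ndarray:
    """Return one rank's rows from a flat float32 buffer without copying the rest."""
    assert rows % tp_size == 0, "sketch assumes an evenly divisible row count"
    rows_per_rank = rows // tp_size
    itemsize = np.dtype(np.float32).itemsize
    offset = tp_rank * rows_per_rank * cols * itemsize
    count = rows_per_rank * cols
    return np.frombuffer(buffer, dtype=np.float32,
                         count=count, offset=offset).reshape(rows_per_rank, cols)

# Usage: only 1/tp_size of the data is materialized per rank.
full = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)
shard = row_slice_view(full.tobytes(), rows=8, cols=4, tp_rank=1, tp_size=4)
assert np.array_equal(shard, full[2:4])
```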
scripts/test_tensor_api.py
Outdated
| DEFAULT_MOONCAKE_CONFIG_PATH_ENV = "MOONCAKE_CONFIG_PATH" | ||
| DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB | ||
| DEFAULT_LOCAL_BUFFER_SIZE = 2 * 1024 * 1024 * 1024 # 2 MB | ||
| DEFAULT_LOCAL_BUFFER_SIZE = 2 * 1024 * 1024 * 1024 # 2 GB |
Copilot AI · Nov 27, 2025
The comment says "2 GB" but the value is 2 * 1024 * 1024 * 1024, which equals 2 GiB (Gibibytes), not GB (Gigabytes). This is inconsistent with the comment on line 15 which correctly uses "GiB".
| DEFAULT_LOCAL_BUFFER_SIZE = 2 * 1024 * 1024 * 1024 # 2 GB | |
| DEFAULT_LOCAL_BUFFER_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB |
Signed-off-by: Xuchun Shang <[email protected]>
Signed-off-by: Xuchun Shang <[email protected]>
Signed-off-by: Xuchun Shang <[email protected]>
Signed-off-by: Xuchun Shang <[email protected]>
Signed-off-by: Xuchun Shang <[email protected]>
Could we show performance data here? @XucSh
@stmatengss There is a test result for getting a 64 MB tensor. get_tensor fetches the full tensor, while get_tensor_with_tp only gets 1/tp of it.
Let us update the Python API docs as well:
Got it. Will merge it after review |
| } | ||
| std::string get_tp_key_name(const std::string &base_key, int rank) { | ||
| return base_key + "_tp_" + std::to_string(rank); |
Currently, it is a workaround. Eventually, we will embed this metadata within another data area.
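In Python terms, the key scheme of this workaround simply derives one key per TP rank from the base key (the base key below is just an example):

```python
def get_tp_key_name(base_key: str, rank: int) -> str:
    # Mirrors the C++ helper quoted above.
    return f"{base_key}_tp_{rank}"

print([get_tp_key_name("layer0.weight", r) for r in range(4)])
# ['layer0.weight_tp_0', 'layer0.weight_tp_1', 'layer0.weight_tp_2', 'layer0.weight_tp_3']
```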
| .cast<std::string>() | ||
| .find("Tensor") != std::string::npos)) { | ||
| LOG(ERROR) << "Input is not a PyTorch tensor"; | ||
| return -static_cast<int>(ErrorCode::INVALID_PARAMS); |
We should use to_py_ret instead now, cc. @ykwd
What I would like to see are some APIs like these. We don't need to split the tensor on the Mooncake side.
Signed-off-by: Xuchun Shang <[email protected]>
Signed-off-by: Xuchun Shang <[email protected]>
Signed-off-by: Xuchun Shang <[email protected]>
Signed-off-by: Xuchun Shang <[email protected]>
| << base_keys.size() | ||
| << ", tensors=" << tensors_list.size(); | ||
| return std::vector<int>( | ||
| base_keys.size(), -static_cast<int>(ErrorCode::INVALID_PARAMS)); |
to_py_ret(ErrorCode::INVALID_PARAMS);
Signed-off-by: Xuchun Shang <[email protected]>
Signed-off-by: Xuchun Shang <[email protected]>
| @staticmethod | ||
| def load_from_env() -> "MooncakeStoreConfig": | ||
| """Load config from a file specified in the environment variable. | ||
| export MOONCAKE_MASTER=10.13.3.232:50051 | ||
| export MOONCAKE_PROTOCOL="rdma" | ||
| export MOONCAKE_DEVICE="" | ||
| export MOONCAKE_TE_META_DATA_SERVER="P2PHANDSHAKE" | ||
| """ | ||
| # other required environment variables... | ||
| """Load configuration from environment variables.""" | ||
| if not os.getenv("MOONCAKE_MASTER"): | ||
| raise ValueError("The environment variable 'MOONCAKE_MASTER' is not set.") | ||
| raise ValueError("Environment variable 'MOONCAKE_MASTER' is not set.") | ||
| return MooncakeStoreConfig( | ||
| local_hostname=os.getenv("LOCAL_HOSTNAME", "localhost"), | ||
| metadata_server=os.getenv("MOONCAKE_TE_META_DATA_SERVER", "P2PHANDSHAKE"), | ||
| global_segment_size=_parse_global_segment_size( | ||
| global_segment_size=parse_global_segment_size( | ||
| os.getenv("MOONCAKE_GLOBAL_SEGMENT_SIZE", DEFAULT_GLOBAL_SEGMENT_SIZE) | ||
| ), | ||
| # Zero copy interface does not need local buffer | ||
| local_buffer_size=DEFAULT_LOCAL_BUFFER_SIZE, | ||
| protocol=os.getenv("MOONCAKE_PROTOCOL", "tcp"), | ||
| device_name=os.getenv("MOONCAKE_DEVICE", ""), | ||
| master_server_address=os.getenv("MOONCAKE_MASTER"), | ||
| master_metrics_port=int( | ||
| os.getenv("MOONCAKE_MASTER_METRICS_PORT", DEFAULT_MASTER_METRICS_PORT) | ||
| ), | ||
| master_metrics_port=int(os.getenv("MOONCAKE_MASTER_METRICS_PORT", DEFAULT_MASTER_METRICS_PORT)), |
No need to rewrite it; import MooncakeStoreConfig from mooncake-wheel/mooncake.
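For reference, loading the config from the environment as the quoted docstring describes would look roughly like this; the import path is an assumption and may differ in mooncake-wheel:

```python
import os
from mooncake.store import MooncakeStoreConfig  # import path assumed

os.environ.setdefault("MOONCAKE_MASTER", "10.13.3.232:50051")
os.environ.setdefault("MOONCAKE_PROTOCOL", "rdma")

config = MooncakeStoreConfig.load_from_env()
print(config.master_server_address, config.protocol)
```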
| if not all(r == 0 for r in results): | ||
| print(f" Iteration {i+1}: FAILED (rc={results})", file=sys.stderr) | ||
| continue | ||
| elapsed_time = end_time - start_time | ||
| put_times.append(elapsed_time) | ||
| # (total_bytes * 8 bits/byte) / (time * 1024^3 Giga) = Gbps | ||
| throughput_gbps = (TOTAL_BATCH_SIZE_BYTES * 8) / (elapsed_time * (1024**3)) | ||
| print(f" Iteration {i+1}: {elapsed_time:.4f} s ({throughput_gbps:.2f} Gbps)") | ||
| if put_times: | ||
| avg_put_time = np.mean(put_times) | ||
| avg_put_throughput = (TOTAL_BATCH_SIZE_BYTES * 8) / (avg_put_time * (1024**3)) | ||
| print(f"Average PUT Time: {avg_put_time:.4f} s") | ||
| print(f"Average PUT Throughput: {avg_put_throughput:.2f} Gbps") | ||
| else: | ||
| print("PUT test failed to complete.") | ||
| # ---------------------------------------- | ||
| # Test 2: batch_get_tensor | ||
| # ---------------------------------------- | ||
| print(f"\n--- Benchmarking batch_get_tensor ({num_iterations} iterations) ---") | ||
| print(" (Pre-populating data for GET test...)") | ||
| store.remove_all() | ||
| rc = store.batch_put_tensor(keys_list, tensors_list) |
Make sure these tests still exist in current script.
It's just a refactor.
Description
Type of Change
How Has This Been Tested?
Checklist