Skip to content

Commit 477b433

Browse files
dboyda authored and meta-codesync[bot] committed
fix rank in init
Summary: Currently, creating an MCCL communicator requires passing a rank at creation time. However, a rank can only be determined after the communicator is created, so semantically it is not possible to provide a rank during creation. Even if we ignore semantics and pass a rank—as users often do—this approach only works for non-fault-tolerant scenarios, where the scheduler creates processes and sets the environment variable `RANK`. In fault-tolerant scenarios, we have elastic process creation: processes can be added or removed at any time. This means we encounter the same issue—at creation time, we cannot specify ranks, as ranks can only be determined at initialization, once all participants are known.

To address this, we are moving MCCL to a new API where ranks are not provided at creation time. At initialization, we support two options:

1. **User does not care about rank order:** MCCL `init` accepts a `std::unordered_set` of URLs, and MCCL is free to assign ranks internally based on its own considerations.
2. **User wants to specify rank order:** For non-fault-tolerant cases, where ranks are defined by the environment variable `RANK`, the user passes a `std::vector` of URLs to MCCL. MCCL will respect the order of URLs and assign ranks accordingly. Using a vector also ensures rank properties are satisfied: ranks go sequentially from 0 to `nRanks - 1`, with no repetitions or missing ranks.

Reviewed By: saifhhasan
Differential Revision: D84839780
fbshipit-source-id: c482b47976cf8711fca737c0f107896bc3bfe558
1 parent aa3bd76 commit 477b433

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

comms/ctran/tests/CtranXPlatUtUtils.cc

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -248,7 +248,6 @@ TestCtranCommRAII::TestCtranCommRAII(std::unique_ptr<mccl::McclComm> mcclComm)
248248
std::unique_ptr<TestCtranCommRAII> createDummyCtranComm() {
249249
CHECK_EQ(ctran::utils::commCudaLibraryInit(), commSuccess);
250250
mccl::McclCommCreateOpts mcclCreateOpts{
251-
.rank = 0,
252251
.cudaDeviceId = 0,
253252
.enableFaultTolerance = false,
254253
};
@@ -259,7 +258,7 @@ std::unique_ptr<TestCtranCommRAII> createDummyCtranComm() {
259258
auto initWorkHandle = mcclComm->init(
260259
mccl::InitOpts{
261260
.uuid = uuid,
262-
.urls = {initURL},
261+
.urls = std::unordered_set<mccl::InitURL>{initURL},
263262
});
264263
initWorkHandle->waitCpu();
265264
auto initResult = initWorkHandle->getResult();

comms/ctran/tests/CtranXPlatUtUtils.h

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -288,9 +288,9 @@ class CtranDistTest : public ::testing::Test {
288288
return globalRank == 0;
289289
}
290290

291-
std::unordered_set<std::string>
291+
std::vector<std::string>
292292
exchangeInitUrls(const std::string& selfUrl, int numRanks, int selfRank) {
293-
std::unordered_set<std::string> res;
293+
std::vector<std::string> res(numRanks);
294294
if (getInitEnvType() == InitEnvType::TCP_STORE) {
295295
std::vector<std::string> rankKeys(numRanks);
296296
const auto keyUid = getTcpStoreKey(TcpStorePhase::INIT);
@@ -305,8 +305,9 @@ class CtranDistTest : public ::testing::Test {
305305
tcpStore_->wait(rankKeys);
306306
if (tcpStore_->check(rankKeys)) {
307307
auto rankUrls = tcpStore_->multiGet(rankKeys);
308-
for (const auto& url : rankUrls) {
309-
res.emplace(std::string(url.begin(), url.end()));
308+
for (int i = 0; i < numRanks; ++i) {
309+
const auto& url = rankUrls.at(i);
310+
res[i] = std::string(url.begin(), url.end());
310311
}
311312
} else {
312313
LOG(FATAL) << "TCPStore key check returned false";
@@ -326,7 +327,9 @@ class CtranDistTest : public ::testing::Test {
326327
MPI_CHAR,
327328
MPI_COMM_WORLD);
328329
for (int i = 0; i < numRanks; ++i) {
329-
res.emplace(std::string(urls.data() + kMaxUrlLen * i));
330+
const char* start = urls.data() + kMaxUrlLen * i;
331+
size_t len = strnlen(start, kMaxUrlLen);
332+
res[i] = std::string(start, len);
330333
}
331334
}
332335
return res;
@@ -339,7 +342,6 @@ class CtranDistTest : public ::testing::Test {
339342
// TODO: refactor mccl comm creation to generic ctran comm creation
340343
COMMCHECK_TEST(ctran::utils::commCudaLibraryInit());
341344
mccl::McclCommCreateOpts opts{
342-
.rank = globalRank,
343345
.cudaDeviceId = cudaDev,
344346
.timeout = std::chrono::seconds(5),
345347
};

0 commit comments

Comments
 (0)