Skip to content

Commit 9d55211

Browse files
Ben Carver authored and meta-codesync[bot] committed
Raise exception instead of deadlocking based on minimum elem count validation for ctring AllReduce
Summary:

The `ctring` algorithm for the `AllReduce` collective has a bug that causes it to hang when the message size is too small. Based on my empirical testing, this appears to occur when the message size is below 8 bytes.

## Exception Details

When the message size is less than 8 bytes, the code:

1. Throws an `std::runtime_error` with a clear, informative error message
2. Logs the error using `XLOGF(ERR, ...)` before throwing
3. Provides actionable guidance to the user (use a larger message size or a different algorithm)

## Error Message

The error message includes:

* The actual message size received
* The count and datatype size for debugging
* A clear explanation that this would cause a hang
* Suggestions for remediation

Ultimately, this is a short-term fix that will prevent users from experiencing a hang when using the `ctring` `AllReduce` algorithm with message sizes below 8 bytes. Instead, they'll get a clear error message explaining the issue and how to work around it.

Reviewed By: arttianezhu

Differential Revision: D84266985

fbshipit-source-id: ecc340ebaa1bacbce28662d9193f73ed2780c180
1 parent d47dc93 commit 9d55211

File tree

3 files changed

+190
-5
lines changed

3 files changed

+190
-5
lines changed

comms/ctran/algos/AllReduce/AllReduceRing.cc

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,35 @@ commResult_t ctranAllReduceRing(
737737
CtranComm* comm,
738738
cudaStream_t stream,
739739
std::optional<std::chrono::milliseconds> timeout) {
740+
// Check for minimum message size requirement for ctring algorithm.
741+
// The ctring algorithm uses a ring-based approach that shards data across all
742+
// ranks. Each rank must have at least one element in its shard to avoid empty
743+
// chunk transfers that can lead to synchronization deadlocks. Therefore, we
744+
// need at least nRanks elements.
745+
const auto& statex = comm->statex_.get();
746+
const auto nRanks = statex->nRanks();
747+
const auto rank = statex->rank();
748+
const size_t typeSize = static_cast<size_t>(commTypeSize(datatype));
749+
const size_t minRequiredElements = nRanks;
750+
const size_t minRequiredBytes = minRequiredElements * typeSize;
751+
752+
if (count < minRequiredElements) {
753+
std::string errorMsg = fmt::format(
754+
"ctring algorithm requires at least {} elements ({} bytes) for {} ranks, "
755+
"but rank {} got {} elements ({} bytes) with datatype size={} bytes. "
756+
"Each rank needs at least one element per shard. "
757+
"Please use a larger message size or a different allreduce algorithm (e.g., ctdirect).",
758+
minRequiredElements,
759+
minRequiredBytes,
760+
nRanks,
761+
rank,
762+
count,
763+
count * typeSize,
764+
typeSize);
765+
CLOGF(ERR, "{}", errorMsg);
766+
throw ctran::utils::Exception(errorMsg, commInvalidArgument);
767+
}
768+
740769
auto opCount = comm->ctran_->getOpCount();
741770
CTRAN_REDCOLL_INFO(
742771
allReduceAlgoName(ctran::allreduce::ring::myAlgo),
@@ -752,10 +781,6 @@ commResult_t ctranAllReduceRing(
752781
std::vector<std::unique_ptr<struct OpElem>> opGroup;
753782
std::unique_ptr<struct OpElem> op;
754783

755-
const auto& statex = comm->statex_.get();
756-
const auto nRanks = statex->nRanks();
757-
const auto rank = statex->rank();
758-
759784
FB_CHECKTHROW(
760785
typeToFunc.contains(std::make_pair(datatype, redOp)),
761786
"typeToFunc does not contain datatype {} with op {}",

comms/ctran/tests/CtranAllReduceTest.cc

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@
1414
#include "comms/ctran/tests/CtranStandaloneUTUtils.h"
1515
#include "comms/utils/cvars/nccl_cvars.h"
1616

17+
#include "comms/ctran/algos/AllReduce/AllReduceImpl.h"
18+
1719
namespace ctran::testing {
1820

1921
using AllReduceTestParam = std::tuple<std::string, enum NCCL_ALLREDUCE_ALGO>;
22+
using AllReduceMinMsgSizeTestParam = std::tuple<size_t, commDataType_t>;
23+
24+
// Expected outcome of a ctring minimum-size test case; selects which
// assertion CtranAllReduceRingMinSizeTest::runTest applies to the
// ctranAllReduceRing call.
enum class CtranAllReduceRingMinSizeTestOpt {
25+
// The call must not throw and must return commSuccess.
expect_sufficient,
26+
// The call must throw ctran::utils::Exception.
expect_insufficient,
27+
};
2028

2129
class CtranAllReduceTest
2230
: public CtranStandaloneMultiRankBaseTest,
@@ -250,4 +258,156 @@ INSTANTIATE_TEST_SUITE_P(
250258
return std::get<0>(info.param);
251259
});
252260

261+
// Test fixture for ctring minimum message size validation
262+
class CtranAllReduceRingMinSizeTest
263+
: public CtranStandaloneMultiRankBaseTest,
264+
public ::testing::WithParamInterface<AllReduceMinMsgSizeTestParam> {
265+
protected:
266+
static constexpr int kDefaultNumRanks = 4;
267+
static_assert(kDefaultNumRanks % 2 == 0);
268+
static constexpr commRedOp_t kReduceOpType = commSum;
269+
270+
void SetUp() override {
271+
setenv("NCCL_COMM_STATE_DEBUG_TOPO", "nolocal", 1);
272+
setenv("NCCL_IGNORE_TOPO_LOAD_FAILURE", "1", 1);
273+
CtranStandaloneMultiRankBaseTest::SetUp();
274+
}
275+
276+
void startWorkers(int numRanks = kDefaultNumRanks) {
277+
std::vector<std::shared_ptr<::ctran::utils::Abort>> aborts;
278+
aborts.reserve(numRanks);
279+
for (int i = 0; i < numRanks; ++i) {
280+
aborts.push_back(ctran::utils::createAbort(/*enabled=*/true));
281+
}
282+
CtranStandaloneMultiRankBaseTest::startWorkers(numRanks, /*aborts=*/aborts);
283+
}
284+
285+
void runTest(
286+
size_t count,
287+
commDataType_t dt,
288+
enum CtranAllReduceRingMinSizeTestOpt testOpt,
289+
int numRanks = kDefaultNumRanks) {
290+
startWorkers(numRanks);
291+
for (int rank = 0; rank < numRanks; ++rank) {
292+
run(rank, [this, count, dt, testOpt](PerRankState& state) {
293+
ASSERT_TRUE(ctranAllReduceSupport(state.ctranComm.get()));
294+
295+
size_t bufferSize = count * commTypeSize(dt);
296+
if (bufferSize < CTRAN_MIN_REGISTRATION_SIZE) {
297+
bufferSize = CTRAN_MIN_REGISTRATION_SIZE;
298+
}
299+
300+
void* srcHandle;
301+
void* dstHandle;
302+
ASSERT_EQ(
303+
commSuccess,
304+
state.ctranComm->ctran_->commRegister(
305+
state.srcBuffer, bufferSize, &srcHandle));
306+
ASSERT_EQ(
307+
commSuccess,
308+
state.ctranComm->ctran_->commRegister(
309+
state.dstBuffer, bufferSize, &dstHandle));
310+
311+
if (testOpt == CtranAllReduceRingMinSizeTestOpt::expect_sufficient) {
312+
// Should not throw when count >= nRanks
313+
EXPECT_NO_THROW({
314+
auto res = ctranAllReduceRing(
315+
state.srcBuffer,
316+
state.dstBuffer,
317+
count,
318+
dt,
319+
kReduceOpType,
320+
state.ctranComm.get(),
321+
state.stream);
322+
EXPECT_EQ(res, commSuccess);
323+
});
324+
} else {
325+
// Expect ctran::utils::Exception when count < nRanks
326+
EXPECT_THROW(
327+
{
328+
ctranAllReduceRing(
329+
state.srcBuffer,
330+
state.dstBuffer,
331+
count,
332+
dt,
333+
kReduceOpType,
334+
state.ctranComm.get(),
335+
state.stream);
336+
},
337+
ctran::utils::Exception);
338+
}
339+
340+
// ensure async execution completion and no error
341+
EXPECT_EQ(cudaSuccess, cudaStreamSynchronize(state.stream));
342+
343+
// deregistering will happen after streamSync below
344+
ASSERT_EQ(
345+
commSuccess, state.ctranComm->ctran_->commDeregister(dstHandle));
346+
ASSERT_EQ(
347+
commSuccess, state.ctranComm->ctran_->commDeregister(srcHandle));
348+
});
349+
}
350+
}
351+
};
352+
353+
// With at least two ranks, a single element is fewer than nRanks elements,
// so the ring AllReduce is expected to throw.
TEST_P(CtranAllReduceRingMinSizeTest, InsufficientElements_1Element) {
  const auto& param = GetParam();
  const auto numRanks = std::get<0>(param);
  const auto dt = std::get<1>(param);
  ASSERT_FALSE(numRanks <= 1) << "Need at least 2 ranks for this test";

  constexpr size_t kElemCount = 1;
  runTest(
      kElemCount,
      dt,
      CtranAllReduceRingMinSizeTestOpt::expect_insufficient,
      numRanks);
}
359+
360+
// One element fewer than the number of ranks is the largest count that is
// still expected to be rejected with an exception.
TEST_P(CtranAllReduceRingMinSizeTest, InsufficientElements_NRanksMinus1) {
  const auto& param = GetParam();
  const auto numRanks = std::get<0>(param);
  const auto dt = std::get<1>(param);
  ASSERT_FALSE(numRanks <= 1) << "Need at least 2 ranks for this test";

  const size_t elemCount = numRanks - 1;
  runTest(
      elemCount,
      dt,
      CtranAllReduceRingMinSizeTestOpt::expect_insufficient,
      numRanks);
}
369+
370+
// Exactly nRanks elements is the smallest count expected to succeed.
TEST_P(CtranAllReduceRingMinSizeTest, SufficientElements_ExactlyNRanks) {
  const auto& param = GetParam();
  const auto numRanks = std::get<0>(param);
  const auto dt = std::get<1>(param);
  XLOG(INFO) << "SufficientElements_ExactlyNRanks: numRanks: " << numRanks
             << ", dt: " << dt;

  const size_t elemCount = numRanks;
  runTest(
      elemCount,
      dt,
      CtranAllReduceRingMinSizeTestOpt::expect_sufficient,
      numRanks);
}
380+
381+
// One element more than the number of ranks is expected to succeed.
TEST_P(CtranAllReduceRingMinSizeTest, SufficientElements_NRanksPlus1) {
  const auto& param = GetParam();
  const auto numRanks = std::get<0>(param);
  const auto dt = std::get<1>(param);

  const size_t elemCount = numRanks + 1;
  runTest(
      elemCount,
      dt,
      CtranAllReduceRingMinSizeTestOpt::expect_sufficient,
      numRanks);
}
389+
390+
// A 1024-element message is expected to succeed for every parameterized
// rank count.
TEST_P(CtranAllReduceRingMinSizeTest, SufficientElements_LargeMessage) {
  const auto& param = GetParam();
  const auto numRanks = std::get<0>(param);
  const auto dt = std::get<1>(param);

  constexpr size_t kElemCount = 1024;
  runTest(
      kElemCount,
      dt,
      CtranAllReduceRingMinSizeTestOpt::expect_sufficient,
      numRanks);
}
395+
396+
// Instantiate the minimum-size tests for 2, 4, 6 and 8 ranks, each paired
// with the float, int32 and int8 datatypes. Combine() advances its last
// generator fastest, so the parameters come out as (2, float), (2, int32),
// (2, int8), (4, float), ... through (8, int8).
INSTANTIATE_TEST_SUITE_P(
    AllDataTypes,
    CtranAllReduceRingMinSizeTest,
    ::testing::Combine(
        ::testing::Values(2, 4, 6, 8),
        ::testing::Values(commFloat, commInt32, commInt8)));
412+
253413
} // namespace ctran::testing

comms/ctran/tests/CtranStandaloneUTUtils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ void initOnce() {
5656
} // namespace
5757

5858
void CtranStandaloneBaseTest::setupBase() {
59-
setenv("NCCL_CTRAN_ENABLE", "INFO", 1);
59+
setenv("NCCL_CTRAN_ENABLE", "1", 1);
6060
setenv("NCCL_DEBUG", "INFO", 1);
6161

6262
rank = 0;

0 commit comments

Comments
 (0)