Skip to content

Commit 8aad9a9

Browse files
Saif Hasan authored and meta-codesync[bot] committed
Make CtranExRequestImpl.bcast.complete a shared_ptr
Summary: Change `complete` in `CtranExRequestImpl.bcast` from a raw `std::atomic_flag` to a `std::shared_ptr<std::atomic_flag>`. This simplifies the ownership model and removes the need for the no-op deleter trick when creating shared_ptrs in the ncclx wrappers. `CtranExRequestImpl` now owns the completion flag directly as a `shared_ptr`, which can be passed to `submitHost` without any special handling. The change maintains the same functionality while providing cleaner code and a more intuitive ownership model.

Reviewed By: Regina8023

Differential Revision: D86548364

fbshipit-source-id: 43e1edf1e0640075da8da0f5a19098bfa6995ed0
1 parent 1cda6ae commit 8aad9a9

File tree

9 files changed

+17
-16
lines changed

9 files changed

+17
-16
lines changed

comms/ctran/CtranExImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class CtranExRequestImpl {
8686
} recvSyncCtrl;
8787
struct {
8888
// completion is set by GPE thread and checked by calling thread.
89-
std::atomic_flag complete;
89+
std::shared_ptr<std::atomic_flag> complete;
9090
} bcast;
9191
};
9292

comms/ctran/CtranExRequest.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ void CtranExRequestImpl::initialize(Type type, CtranIb* ctranIb) {
5858
void CtranExRequestImpl::initialize(Type type, CtranComm* ctranComm) {
5959
switch (type) {
6060
case BCAST:
61-
bcast.complete.clear();
61+
bcast.complete = std::make_shared<std::atomic_flag>();
62+
bcast.complete->clear();
6263
break;
6364
default:
6465
FB_CHECKABORT(
@@ -75,7 +76,7 @@ void CtranExRequestImpl::initialize(Type type, CtranComm* ctranComm) {
7576
void CtranExRequestImpl::complete() {
7677
switch (type) {
7778
case BCAST:
78-
bcast.complete.test_and_set();
79+
bcast.complete->test_and_set();
7980
break;
8081
// no op for other types
8182
default:
@@ -113,7 +114,7 @@ void CtranExRequestImpl::atComplete(CtranExRequest* req) {
113114
bool CtranExRequest::isComplete() const {
114115
auto reqImpl = reinterpret_cast<CtranExRequestImpl*>(impl_);
115116
if (reqImpl->type == CtranExRequestImpl::BCAST) {
116-
return reqImpl->bcast.complete.test();
117+
return reqImpl->bcast.complete->test();
117118
}
118119
return reqImpl->ibReq.isComplete();
119120
}
@@ -125,7 +126,7 @@ commResult_t CtranExRequest::test(bool& complete) {
125126
// marks completion. Thus, no need to poll backend progress by the calling
126127
// thread.
127128
if (reqImpl->type == CtranExRequestImpl::BCAST) {
128-
complete = reqImpl->bcast.complete.test();
129+
complete = reqImpl->bcast.complete->test();
129130

130131
// Check if there is any error reported by the GPE thread;
131132
// if so, return the error code.
@@ -162,7 +163,7 @@ commResult_t CtranExRequest::wait() {
162163

163164
if (reqImpl->type == CtranExRequestImpl::BCAST) {
164165
// GPE thread is handling the communication, wait for it to complete
165-
reqImpl->bcast.complete.wait(true);
166+
reqImpl->bcast.complete->wait(true);
166167
// Check if there is any error reported by the GPE thread;
167168
// if so, return the error code.
168169
if (reqImpl->asyncErr) {

comms/ctran/algos/Broadcast/BroadcastBinomialTree.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ commResult_t CtranAlgo::broadcastBinomialTree(
440440
size_t count,
441441
commDataType_t datatype,
442442
int root,
443-
std::atomic_flag* cpuFlag) {
443+
std::shared_ptr<std::atomic_flag> cpuFlag) {
444444
auto opCount = ctran_->getOpCount();
445445
CTRAN_HOST_COLL_INFO(
446446
broadcastAlgoName(myAlgo).c_str(),

comms/ctran/algos/CtranAlgo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class CtranAlgo {
5757
size_t count,
5858
commDataType_t datatype,
5959
int root,
60-
std::atomic_flag* cpuFlag);
60+
std::shared_ptr<std::atomic_flag> cpuFlag);
6161

6262
commResult_t initTmpBufs();
6363
commResult_t initAllReduceDirectResource(int nBlocks, cudaStream_t stream);

comms/ctran/gpe/CtranGpe.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -389,13 +389,13 @@ commResult_t CtranGpe::submitHost(
389389
std::vector<std::unique_ptr<struct OpElem>> opGroup,
390390
opFunc func,
391391
KernelConfig& kernelConfig,
392-
std::atomic_flag* cpuFlag) {
392+
std::shared_ptr<std::atomic_flag> cpuFlag) {
393393
return this->pimpl->submitHost(
394394
CtranGpeCmd::TypeEnum::GRAPH_ENQUEUE,
395395
std::move(opGroup),
396396
func,
397397
kernelConfig,
398-
cpuFlag);
398+
std::move(cpuFlag));
399399
}
400400

401401
commResult_t CtranGpe::allocKernelElems(

comms/ctran/gpe/CtranGpe.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ class CtranGpe {
392392
std::vector<std::unique_ptr<struct OpElem>> opGroup,
393393
opFunc func,
394394
KernelConfig& kernelConfig,
395-
std::atomic_flag* cpuFlag);
395+
std::shared_ptr<std::atomic_flag> cpuFlag);
396396

397397
// Allocate numElems number of p2pElem objects from internal pool.
398398
// When free objects are not enough, it will be in blocking wait and reclaim

comms/ctran/gpe/CtranGpeImpl.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,13 +394,13 @@ commResult_t CtranGpe::Impl::submitHost(
394394
std::vector<std::unique_ptr<struct OpElem>> opGroup,
395395
opFunc func,
396396
KernelConfig& kernelConfig,
397-
std::atomic_flag* cpuFlag) {
397+
std::shared_ptr<std::atomic_flag> cpuFlag) {
398398
// Enqueue op to gpeThread if any op is appended
399399
if (!opGroup.empty()) {
400400
class CtranGpeCmd* cmd = new class CtranGpeCmd;
401401
cmd->type = type;
402402
cmd->kernelFlag = nullptr;
403-
cmd->cpuFlag = cpuFlag;
403+
cmd->cpuFlag = std::move(cpuFlag);
404404

405405
if (type == CtranGpeCmd::TypeEnum::GRAPH_ENQUEUE) {
406406
cmd->coll.opGroup = std::move(opGroup);

comms/ctran/gpe/CtranGpeImpl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class CtranGpeCmd {
124124
// kernelFlag to assist device mem communication
125125
KernelFlagItem* kernelFlag{nullptr};
126126
// cpuFlag to track completion of host mem communication
127-
std::atomic_flag* cpuFlag{nullptr};
127+
std::shared_ptr<std::atomic_flag> cpuFlag{nullptr};
128128

129129
bool persistent{false};
130130

@@ -202,7 +202,7 @@ class CtranGpe::Impl {
202202
std::vector<std::unique_ptr<struct OpElem>> opGroup,
203203
opFunc func,
204204
KernelConfig& kernelConfig,
205-
std::atomic_flag* cpuFlag);
205+
std::shared_ptr<std::atomic_flag> cpuFlag);
206206

207207
// start the GPE thread.
208208
void start();

comms/ncclx/v2_27/meta/wrapper/CtranExComm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ ncclResult_t CtranExComm::broadcast(
144144
count,
145145
ncclToMetaComm(datatype),
146146
root,
147-
&reqImpl->bcast.complete)));
147+
reqImpl->bcast.complete)));
148148

149149
*req = reqPtr;
150150
return ncclSuccess;

0 commit comments

Comments
 (0)