Skip to content

Commit a62a1c1

Browse files
Regina8023 authored and meta-codesync[bot] committed
Refactor cudagraph-aware AllToAllvDynamic: move cudagraph-aware preparation out of the GPE submit path
Summary: Similar to the AllToAll refactor, this diff moves the cudagraph prepare function for AllToAllvDynamic out of the GPE and into the algo code. It also moves the `prepareCudagraphAware*` implementations into `CudaGraphUtilsImpl.cc`. Because `PersistentObj` is a variant type, implementing each function in its corresponding PImpl.cc file would require that file to include the headers of every type in the variant (AllToAllPImpl.cc would need to include AllToAllvDynamicPImpl.h, and vice versa). Keeping all implementations in one file means the full set of headers is included only once. Reviewed By: minsii Differential Revision: D85492033 fbshipit-source-id: d53c833b78e60ee5671355f5c9b18269453d3331
1 parent 9d55211 commit a62a1c1

File tree

8 files changed

+118
-100
lines changed

8 files changed

+118
-100
lines changed

comms/ctran/algos/AllToAll/AllToAllPImpl.cc

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -391,32 +391,4 @@ commResult_t AlgoImpl::updatePersistentFuncAndOp(
391391
(void*)op);
392392
return commSuccess;
393393
}
394-
395-
commResult_t prepareCudagraphAwareAllToAll(
396-
opFunc& opFunc,
397-
struct OpElem* op,
398-
PersistentObj& pObj) {
399-
pObj = std::make_unique<AlgoImpl>(op->comm_, op->stream);
400-
auto algoImplPtr = std::get<std::unique_ptr<AlgoImpl>>(pObj).get();
401-
if (!algoImplPtr) {
402-
return commSystemError;
403-
}
404-
405-
FB_COMMCHECK(algoImplPtr->setPArgs(
406-
op->alltoall.recvbuff,
407-
op->alltoall.count * op->comm_->statex_->nRanks(),
408-
true /* skipCtrlMsg */,
409-
op->alltoall.datatype));
410-
411-
// Exchange mem handles and record in pArgs. This will not be captured
412-
// by cudagraph.
413-
FB_COMMCHECK(algoImplPtr->init());
414-
415-
// Replace gpe func by the persistent version (skip exchanging mem
416-
// handle); and OpGroup by the persistent op which has the remote
417-
// handles recorded.
418-
419-
FB_COMMCHECK(algoImplPtr->updatePersistentFuncAndOp(opFunc, op));
420-
return commSuccess;
421-
}
422394
} // namespace ctran::alltoallp

comms/ctran/algos/AllToAll/AllToAllvDynamicPImpl.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Meta Platforms, Inc. and affiliates.
22

33
#include "comms/ctran/algos/AllToAll/AllToAllvDynamicPImpl.h"
4+
#include "Types.h"
45
#include "comms/ctran/CtranComm.h"
56
#include "comms/ctran/algos/AllToAll/AllToAllvDynamicCommon.h"
67
#include "comms/ctran/algos/CtranAlgo.h"
@@ -224,22 +225,16 @@ commResult_t AlgoImpl::init() {
224225

225226
commResult_t AlgoImpl::updatePersistFuncAndOp(
226227
opFunc& opFunc,
227-
std::vector<std::unique_ptr<struct OpElem>>& opGroup,
228228
struct OpElem* op) {
229229
opFunc = gpeFn;
230-
auto new_op = std::make_unique<OpElem>(op);
231-
// The original op is not needed and will/may be destroyed. So set kElem to
232-
// nullptr to avoid it be freed.
233-
op->alltoallv_dynamic.kElem = nullptr;
234230
// FIXME: only support split_non_contig for now
235-
new_op->type = OpElem::opType::ALLTOALLV_DYNAMIC_SPLIT_NON_CONTIG_P;
236-
new_op->alltoallv_dynamic.pArgs = &pArgs;
237-
opGroup.push_back(std::move(new_op));
231+
op->type = OpElem::opType::ALLTOALLV_DYNAMIC_SPLIT_NON_CONTIG_P;
232+
op->alltoallv_dynamic.pArgs = &pArgs;
238233
CLOGF_TRACE(
239234
COLL,
240235
"AllToAllvDynamicP: rank {} updated op to {} and gpeFn to persistent version.",
241236
comm_->statex_->rank(),
242-
(void*)opGroup.front().get());
237+
(void*)op);
243238
return commSuccess;
244239
}
245240
} // namespace ctran::alltoallvdynamicp

comms/ctran/algos/AllToAll/AllToAllvDynamicPImpl.h

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,13 @@
33
#pragma once
44

55
#include <folly/synchronization/CallOnce.h>
6+
#include "Types.h"
67
#include "comms/ctran/CtranComm.h"
78
#include "comms/ctran/gpe/CtranGpe.h"
89
#include "comms/ctran/mapper/CtranMapperTypes.h"
910
#include "comms/utils/cvars/nccl_cvars.h"
1011

1112
namespace ctran::alltoallvdynamicp {
12-
struct PersistArgs {
13-
std::vector<void*> recvbuffs;
14-
std::vector<void*> recvHdls;
15-
size_t maxSendCount;
16-
size_t maxRecvCount;
17-
commDataType_t datatype;
18-
std::vector<void*> remoteRecvBuffs;
19-
std::vector<struct CtranMapperRemoteAccessKey> remoteAccessKeys;
20-
};
21-
2213
class AlgoImpl {
2314
public:
2415
PersistArgs pArgs;
@@ -29,10 +20,7 @@ class AlgoImpl {
2920

3021
commResult_t init();
3122

32-
commResult_t updatePersistFuncAndOp(
33-
opFunc& opFunc,
34-
std::vector<std::unique_ptr<struct OpElem>>& opGroup,
35-
struct OpElem* op);
23+
commResult_t updatePersistFuncAndOp(opFunc& opFunc, struct OpElem* op);
3624

3725
static inline const std::string algoName(enum NCCL_ALLTOALL_ALGO algo) {
3826
switch (algo) {
@@ -47,4 +35,9 @@ class AlgoImpl {
4735
CtranComm* comm_{nullptr};
4836
cudaStream_t stream_{nullptr};
4937
};
38+
39+
commResult_t prepareCudagraphAwareAllToAllvDynamic(
40+
opFunc& opFunc,
41+
struct OpElem* op,
42+
PersistentObj& pObj);
5043
} // namespace ctran::alltoallvdynamicp

comms/ctran/algos/AllToAll/AllToallvDynamicSplitNonContig.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
#include <cuda_fp16.h>
44
#include <cstddef>
55

6+
#include "Types.h"
67
#include "comms/ctran/CtranComm.h"
78
#include "comms/ctran/algos/AllToAll/AllToAllvDynamicCommon.h"
9+
#include "comms/ctran/algos/AllToAll/AllToAllvDynamicPImpl.h"
810
#include "comms/ctran/algos/CtranAlgo.h"
911
#include "comms/ctran/gpe/CtranGpe.h"
1012

@@ -88,11 +90,19 @@ commResult_t ctranAlltoallvDynamicSplitNonContig(
8890
XCHECK(alltoallvDynamicSplitNonContigKerns.contains(datatype))
8991
<< "alltoallvDynamicSplitNonContigKerns does not contain datatype "
9092
<< datatype;
93+
94+
ctran::PreLaunchGraphPrepareFn graphPrepareFn = nullptr;
95+
if (NCCL_CTRAN_ALLTOALL_CUDAGRAPH_AWARE_ENABLE) {
96+
graphPrepareFn =
97+
ctran::alltoallvdynamicp::prepareCudagraphAwareAllToAllvDynamic;
98+
}
9199
FB_COMMCHECK(comm->ctran_->gpe->submit(
92100
std::move(opGroup),
93101
opIbImpl,
94102
config,
95-
alltoallvDynamicSplitNonContigKerns.at(datatype)));
103+
alltoallvDynamicSplitNonContigKerns.at(datatype),
104+
std::nullopt, /* timeout */
105+
graphPrepareFn));
96106

97107
return commSuccess;
98108
}

comms/ctran/algos/AllToAll/Types.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
#include "comms/ctran/mapper/CtranMapperTypes.h"
77
#include "comms/utils/commSpecs.h"
88

9-
namespace ctran::alltoallp {
9+
namespace ctran {
10+
namespace alltoallp {
1011
struct PersistArgs {
1112
void* recvbuff;
1213
void* recvHdl;
@@ -18,4 +19,19 @@ struct PersistArgs {
1819
};
1920

2021
class AlgoImpl;
21-
} // namespace ctran::alltoallp
22+
} // namespace alltoallp
23+
24+
namespace alltoallvdynamicp {
25+
struct PersistArgs {
26+
std::vector<void*> recvbuffs;
27+
std::vector<void*> recvHdls;
28+
size_t maxSendCount;
29+
size_t maxRecvCount;
30+
commDataType_t datatype;
31+
std::vector<void*> remoteRecvBuffs;
32+
std::vector<struct CtranMapperRemoteAccessKey> remoteAccessKeys;
33+
};
34+
35+
class AlgoImpl;
36+
} // namespace alltoallvdynamicp
37+
} // namespace ctran
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#include "comms/ctran/CtranComm.h"
2+
#include "comms/ctran/algos/AllToAll/AllToAllPImpl.h"
3+
#include "comms/ctran/algos/AllToAll/AllToAllvDynamicPImpl.h"
4+
#include "comms/ctran/algos/AllToAll/Types.h"
5+
6+
namespace ctran {
7+
namespace alltoallp {
8+
commResult_t prepareCudagraphAwareAllToAll(
9+
opFunc& opFunc,
10+
struct OpElem* op,
11+
PersistentObj& pObj) {
12+
pObj = std::make_unique<AlgoImpl>(op->comm_, op->stream);
13+
auto algoImplPtr = std::get<std::unique_ptr<AlgoImpl>>(pObj).get();
14+
if (!algoImplPtr) {
15+
return commSystemError;
16+
}
17+
18+
FB_COMMCHECK(algoImplPtr->setPArgs(
19+
op->alltoall.recvbuff,
20+
op->alltoall.count * op->comm_->statex_->nRanks(),
21+
true /* skipCtrlMsg */,
22+
op->alltoall.datatype));
23+
24+
// Exchange mem handles and record in pArgs. This will not be captured
25+
// by cudagraph.
26+
FB_COMMCHECK(algoImplPtr->init());
27+
28+
// Replace gpe func by the persistent version (skip exchanging mem
29+
// handle); and OpGroup by the persistent op which has the remote
30+
// handles recorded.
31+
32+
FB_COMMCHECK(algoImplPtr->updatePersistentFuncAndOp(opFunc, op));
33+
return commSuccess;
34+
}
35+
36+
} // namespace alltoallp
37+
namespace alltoallvdynamicp {
38+
commResult_t prepareCudagraphAwareAllToAllvDynamic(
39+
opFunc& opFunc,
40+
struct OpElem* op,
41+
PersistentObj& pObj) {
42+
pObj = std::make_unique<AlgoImpl>(op->comm_, op->stream);
43+
auto algoImplPtr = std::get<std::unique_ptr<AlgoImpl>>(pObj).get();
44+
if (!algoImplPtr) {
45+
return commSystemError;
46+
}
47+
48+
const int nRanks = op->comm_->statex_->nRanks();
49+
std::vector<void*> recvbuffs(nRanks);
50+
for (int i = 0; i < nRanks; i++) {
51+
recvbuffs[i] = op->alltoallv_dynamic.recvbuffs[i];
52+
}
53+
// FIXME: confirm if sendbuffs are also persistent, so we don't need to
54+
// search handle for sendbuffs every time
55+
algoImplPtr->pArgs = {
56+
.recvbuffs = recvbuffs,
57+
.maxSendCount = op->alltoallv_dynamic.maxSendcount,
58+
.maxRecvCount = op->alltoallv_dynamic.maxRecvcount,
59+
.datatype = op->alltoallv_dynamic.datatype,
60+
};
61+
62+
// Exchange mem handles and record in pArgs. This will not be captured
63+
// by cudagraph.
64+
FB_COMMCHECK(algoImplPtr->init());
65+
66+
// Replace gpe func by the persistent version (skip exchanging mem
67+
// handle); and OpGroup by the persistent op which has the remote
68+
// handles recorded.
69+
70+
FB_COMMCHECK(algoImplPtr->updatePersistFuncAndOp(opFunc, op));
71+
return commSuccess;
72+
}
73+
} // namespace alltoallvdynamicp
74+
} // namespace ctran

comms/ctran/gpe/CtranGpe.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ typedef commResult_t (*opFunc)(
2222
const std::vector<std::unique_ptr<struct OpElem>>& opGroup);
2323

2424
namespace ctran {
25-
using PersistentObj =
26-
std::variant<std::monostate, std::unique_ptr<ctran::alltoallp::AlgoImpl>>;
25+
using PersistentObj = std::variant<
26+
std::monostate,
27+
std::unique_ptr<alltoallp::AlgoImpl>,
28+
std::unique_ptr<alltoallvdynamicp::AlgoImpl>>;
2729
using PreLaunchGraphPrepareFn =
2830
commResult_t (*)(opFunc& opFunc, struct OpElem* op, PersistentObj& pObj);
2931
} // namespace ctran

comms/ctran/gpe/CtranGpeImpl.cc

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -46,41 +46,6 @@ static std::unordered_map<KernelConfig::KernelType, const std::string>
4646
"AllToAllvDynamicSplitNonContig"},
4747
};
4848

49-
namespace {
50-
commResult_t updateGraphAwareAlltoAllvDynamicCmd(CtranGpeCmd* cmd) {
51-
auto op = cmd->coll.opGroup.front().get();
52-
ctran::alltoallvdynamicp::AlgoImpl* algo =
53-
new ctran::alltoallvdynamicp::AlgoImpl(op->comm_, op->stream);
54-
if (!algo) {
55-
return commSystemError;
56-
}
57-
const int nRanks = op->comm_->statex_->nRanks();
58-
std::vector<void*> recvbuffs(nRanks);
59-
for (int i = 0; i < nRanks; i++) {
60-
recvbuffs[i] = op->alltoallv_dynamic.recvbuffs[i];
61-
}
62-
// FIXME: confirm if sendbuffs are also persistent, so we don't need to
63-
// search handle for sendbuffs every time
64-
algo->pArgs = {
65-
.recvbuffs = recvbuffs,
66-
.maxSendCount = op->alltoallv_dynamic.maxSendcount,
67-
.maxRecvCount = op->alltoallv_dynamic.maxRecvcount,
68-
.datatype = op->alltoallv_dynamic.datatype,
69-
};
70-
// Exchange mem handles and record in pArgs. This will not be captured
71-
// by cudagraph.
72-
FB_COMMCHECK(algo->init());
73-
74-
// Replace gpe func by the persistent version (skip exchanging mem
75-
// handle); and OpGroup by the persistent op which has the remote
76-
// handles recorded.
77-
std::vector<std::unique_ptr<struct OpElem>> newOpGroup;
78-
FB_COMMCHECK(algo->updatePersistFuncAndOp(cmd->coll.func, newOpGroup, op));
79-
cmd->coll.opGroup = std::move(newOpGroup);
80-
return commSuccess;
81-
}
82-
} // namespace
83-
8449
CtranGpe::Impl::Impl() {
8550
this->kernelFlagPool = std::unique_ptr<KernelFlagPool>(
8651
new KernelFlagPool(NCCL_CTRAN_NUM_KERNEL_FLAGS));
@@ -254,15 +219,6 @@ commResult_t CtranGpe::Impl::submit(
254219
if (streamCaptureInfo.status == cudaStreamCaptureStatusActive) {
255220
FB_COMMCHECK(preLaunchGraphPrepare(cmd, graphPrepareFn));
256221
struct cmdCbPlan* plan = new struct cmdCbPlan;
257-
// cudagraph-aware alltoall: transfer alltoall to alltoallPersistent for
258-
// perf optimization
259-
auto op = cmd->coll.opGroup.front().get();
260-
if (NCCL_CTRAN_ALLTOALL_CUDAGRAPH_AWARE_ENABLE &&
261-
op->type == OpElem::opType::ALLTOALLV_DYNAMIC_SPLIT_NON_CONTIG) {
262-
// FIXME: this should control by hints passed from user instead of CVAR
263-
// so we can have per-collective control
264-
updateGraphAwareAlltoAllvDynamicCmd(cmd);
265-
}
266222
plan->cmd = cmd;
267223
plan->gpe = this->gpe;
268224
cmd->persistent = true;

0 commit comments

Comments
 (0)