Skip to content

Commit f9a45be

Browse files
arttianezhumeta-codesync[bot]
authored andcommitted
AllReduce ctring enable more datatypes
Summary: Make mccl pg/upper level tests pass, since we are switching to ctring by default in MCCL (D85346277) fp8, bf16 are excluded, as they commonly require casting to higher precision types for proper reduction calculation. Current Ctran handling is incomplete with these types. Reviewed By: dboyda Differential Revision: D85599686 fbshipit-source-id: f04e1ce7a0f8689e56924bdf8f60bea3e57ca27f
1 parent a62a1c1 commit f9a45be

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

comms/ctran/algos/AllReduce/AllReduceRing.cc

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -653,17 +653,25 @@ static commResult_t impl(
653653
algoCtx.opRounds[Op::kSendTrans].totalRounds ||
654654
algoCtx.opRounds[Op::kRecvRedCopy].done <
655655
algoCtx.opRounds[Op::kRecvRedCopy].totalRounds) {
656-
// TODO: enable other data types
657-
switch (op->allreduce.datatype) {
658-
case commFloat32:
659-
case commUint64:
660-
case commInt32:
661-
case commInt8:
662-
break;
663-
default:
664-
throw ctran::utils::Exception(
665-
fmt::format("Unsupported data type {}", op->allreduce.datatype),
666-
commInvalidArgument);
656+
if (op->allreduce.datatype == commInt8 ||
657+
op->allreduce.datatype == commChar ||
658+
op->allreduce.datatype == commUint8 ||
659+
op->allreduce.datatype == commInt32 ||
660+
op->allreduce.datatype == commInt ||
661+
op->allreduce.datatype == commUint32 ||
662+
op->allreduce.datatype == commInt64 ||
663+
op->allreduce.datatype == commUint64 ||
664+
op->allreduce.datatype == commFloat16 ||
665+
op->allreduce.datatype == commHalf ||
666+
op->allreduce.datatype == commFloat32 ||
667+
op->allreduce.datatype == commFloat ||
668+
op->allreduce.datatype == commFloat64 ||
669+
op->allreduce.datatype == commDouble) {
670+
// TODO: enable other data types
671+
} else {
672+
throw ctran::utils::Exception(
673+
fmt::format("Unsupported data type {}", op->allreduce.datatype),
674+
commInvalidArgument);
667675
}
668676
progressSend(args, resource, algoCtx, dataSResps, bufSyncRResps);
669677
HOST_ABORT();

0 commit comments

Comments
 (0)