@@ -138,6 +138,11 @@ class VanillaCPU final: public DispatchStub {
                                                             const ReduceOptions& opts,
                                                             ProcessGroupCCL& pg) override;
 
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                    std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                    const ReduceScatterOptions& opts,
+                                                                    ProcessGroupCCL& pg_ccl) override;
+
   c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_scatter_base_(at::Tensor& outputTensor,
                                                                           at::Tensor& inputTensor,
                                                                           const ReduceScatterOptions& opts,
@@ -194,6 +199,11 @@ class VanillaCPU final: public DispatchStub {
   std::condition_variable queueProduceCV_;
   std::condition_variable queueConsumeCV_;
 
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> _reduce_oop(at::Tensor& outputTensor,
+                                                                at::Tensor& inputTensor,
+                                                                const ReduceOptions& opts,
+                                                                ProcessGroupCCL& pg_ccl);
+
 };
 
 struct RegisterCPUPMethods {
@@ -388,6 +398,45 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::reduce_(std::vecto
   return work;
 }
 
+// _reduce_oop implements an out-of-place reduce procedure.
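+// It backs the unequal-chunk fallback in reduce_scatter_ below, where each rank serves as the root of one reduce.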
+c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::_reduce_oop(at::Tensor& outputTensor,
+                                                                          at::Tensor& inputTensor,
+                                                                          const ReduceOptions& opts,
+                                                                          ProcessGroupCCL& pg_ccl) {
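+  // Fold the root tensor index into the root rank to obtain the oneCCL root.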
+  const int root = opts.rootRank + opts.rootTensor;
+  std::vector<at::Tensor> inputTensors{inputTensor};
+  std::vector<at::Tensor> outputTensors{outputTensor};
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> work;
+  work = collective<get_ccl_comms, CPUWorkCCL>(
+    pg_ccl,
+    inputTensors,
+    outputTensors,
+    [=](at::Tensor input,
+        at::Tensor output,
+        ccl::reduce_attr attr,
+        ccl::communicator& comm) {
+
+      ccl::event ret_evt;
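+      // Serialize the oneCCL call through the process group's global mutex.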
+      call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() {
+        CCL_CHECK(ret_evt = ccl::reduce(input.data_ptr(),
+                                        output.data_ptr(),
+                                        (size_t) input.numel(),
+                                        cclDatatypes.at(input.scalar_type()),
+                                        cclOps.at(opts.reduceOp),
+                                        root,
+                                        comm));
+      });
+      return ret_evt;
+
+    },
+    c10d::OpType::REDUCE,
+    "oneccl_bindings_for_pytorch::cpu_work::_reduce_oop");
+
+  work->debugName = std::string("cpu::_reduce_oop");
+  enqueue(work);
+  return work;
+}
+
 c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::broadcast_(std::vector<at::Tensor>& tensors,
                                                                          const BroadcastOptions& opts,
                                                                          ProcessGroupCCL& pg) {
@@ -596,6 +645,65 @@ c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::gather_(std::vecto
   return work;
 }
 
+c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::reduce_scatter_(std::vector<at::Tensor>& outputTensors,
+                                                                              std::vector<std::vector<at::Tensor>>& inputTensors,
+                                                                              const ReduceScatterOptions& opts,
+                                                                              ProcessGroupCCL& pg_ccl) {
+  checkSingleTensor(outputTensors);
+  auto outputTensor = outputTensors.back();
+  auto inputTensors_ = inputTensors.back();
+  bool same_size = check_same_size(inputTensors_);
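+  // ccl::reduce_scatter requires equally sized per-rank chunks; otherwise fall back to one reduce per rank below.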
+  c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> work;
+  if (same_size) {
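+    // Flatten the per-rank input tensors into one contiguous send buffer.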
+    auto inputFlattened = newLikeFlat(inputTensors_);
+    for (const auto j : c10::irange(inputTensors_.size())) {
+      inputFlattened[j].copy_(inputTensors_[j], true);
+    }
+    std::vector<at::Tensor> flattenedInputTensors{inputFlattened};
+    work = collective<get_ccl_comms, CPUWorkCCL>(
+      pg_ccl,
+      flattenedInputTensors,
+      outputTensors,
+      [=](at::Tensor input,
+          at::Tensor output,
+          ccl::reduce_attr attr,
+          ccl::communicator& comm) {
+
+        ccl::event ret_evt;
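+        // output.numel() is the per-rank receive count expected by ccl::reduce_scatter.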
+        call_with_lock(c10d::ProcessGroupCCL::globalMutex, [&]() {
+          CCL_CHECK(ret_evt = ccl::reduce_scatter(input.data_ptr(),
+                                                  output.data_ptr(),
+                                                  (size_t) output.numel(),
+                                                  cclDatatypes.at(input.scalar_type()),
+                                                  cclOps.at(opts.reduceOp),
+                                                  comm));
+        });
+        return ret_evt;
+
+      },
+      c10d::OpType::REDUCE_SCATTER,
+      "oneccl_bindings_for_pytorch::cpu_work::reduce_scatter");
+    work->debugName = std::string("cpu::reduce_scatter");
+    enqueue(work);
+    return work;
+
+  } else {
+    // Use multiple reduces to simulate reduce_scatter.
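+    // Rank i is the root of the i-th reduce: only the reduce rooted at this rank writes into outputTensor,
+    // the remaining reduces write back into the corresponding input tensors.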
+    const auto num_reduces = inputTensors_.size();
+    for (const int i : c10::irange(num_reduces)) {
+      auto& input = inputTensors_[i];
+      auto& output = (i == pg_ccl.getRank()) ? outputTensor : input;
+      auto reduceOpts = ReduceOptions{
+        opts.reduceOp,
+        static_cast<int64_t>(i),
+        static_cast<int64_t>(0),
+        opts.timeout};
+      work = _reduce_oop(output, input, reduceOpts, pg_ccl);
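+      // work is overwritten each iteration, so only the last reduce's handle is returned to the caller.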
+    }
+    return work;
+  }
+}
+
 c10::intrusive_ptr<ProcessGroupCCL::AsyncWorkCCL> VanillaCPU::scatter_(std::vector<at::Tensor>& outputTensors,
                                                                        std::vector<std::vector<at::Tensor>>& inputTensors,
                                                                        const ScatterOptions& opts,