@@ -150,13 +150,13 @@ void BroadcastFromWorker0(ffi::Optional<Tensor> send, bool in_group, Tensor recv
150150 const void * send_data = [&]() -> const void * {
151151 if (is_sender) {
152152 CHECK (send.defined ());
153- CHECK (send.value ().Shape ()-> Product () == recv.Shape ()-> Product ());
153+ CHECK (send.value ().Shape (). Product () == recv.Shape (). Product ());
154154 return send.value ()->data ;
155155 } else {
156156 return nullptr ;
157157 }
158158 }();
159- int64_t numel = recv.Shape ()-> Product ();
159+ int64_t numel = recv.Shape (). Product ();
160160
161161 deviceStream_t stream = ctx->GetDefaultStream ();
162162 NCCL_CALL (ncclBroadcast (send_data, recv->data , numel,
@@ -176,19 +176,19 @@ void ScatterFromWorker0(ffi::Optional<Tensor> send, bool in_group, Tensor recv)
176176 if (is_sender) {
177177 CHECK (send.defined ()) << " ValueError: buffer `send` must be provided when worker_id == 0." ;
178178 Tensor buffer = send.value ();
179- int64_t numel = buffer.Shape ()-> Product ();
179+ int64_t numel = buffer.Shape (). Product ();
180180 CHECK_EQ (numel % num_receiver, 0 ) << " ValueError: Scattering evenly requires that the number "
181181 " of elements in the buffer to be "
182182 " divisible by the number of workers, but got numel = "
183183 << numel << " and " << num_receiver << " workers." ;
184184 DataType dtype (buffer->dtype );
185185 int64_t numel_per_shard = numel / num_receiver;
186186 int64_t bytes_per_shard = numel_per_shard * dtype.bytes ();
187- CHECK_EQ (numel_per_shard, recv.Shape ()-> Product ())
187+ CHECK_EQ (numel_per_shard, recv.Shape (). Product ())
188188 << " ValueError: The number of elements in buffer `recv` must be the same as each shard "
189189 " of "
190190 " buffer `send`. `send.size` is "
191- << numel << " , but `recv.size` is " << recv.Shape ()-> Product () << " ." ;
191+ << numel << " , but `recv.size` is " << recv.Shape (). Product () << " ." ;
192192 NCCL_CALL (ncclGroupStart ());
193193 uint8_t * data = static_cast <uint8_t *>(buffer->data );
194194 for (int i = 0 ; i < num_receiver; ++i) {
@@ -204,7 +204,7 @@ void ScatterFromWorker0(ffi::Optional<Tensor> send, bool in_group, Tensor recv)
204204 }
205205 NCCL_CALL (ncclGroupStart ());
206206 }
207- int64_t numel = recv.Shape ()-> Product ();
207+ int64_t numel = recv.Shape (). Product ();
208208 DataType dtype (recv->dtype );
209209 NCCL_CALL (ncclRecv (recv->data , numel, AsNCCLDataType (dtype), 0 ,
210210 in_group ? ctx->group_comm : ctx->global_comm , stream));
@@ -223,19 +223,19 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional<Tensor> recv) {
223223 if (is_sender) {
224224 CHECK (recv.defined ()) << " ValueError: buffer `recv` must be provided when worker_id == 0." ;
225225 Tensor buffer = recv.value ();
226- int64_t numel = buffer.Shape ()-> Product ();
226+ int64_t numel = buffer.Shape (). Product ();
227227 CHECK_EQ (numel % num_receiver, 0 ) << " ValueError: Gathering evenly requires that the number "
228228 " of elements in the buffer to be "
229229 " divisible by the number of workers, but got numel = "
230230 << numel << " and " << num_receiver << " workers." ;
231231 DataType dtype (buffer->dtype );
232232 int64_t numel_per_shard = numel / num_receiver;
233233 int64_t bytes_per_shard = numel_per_shard * dtype.bytes ();
234- CHECK_EQ (numel_per_shard, send.Shape ()-> Product ())
234+ CHECK_EQ (numel_per_shard, send.Shape (). Product ())
235235 << " ValueError: The number of elements in buffer `send` must be the same as each shard "
236236 " of "
237237 " buffer `recv`. `recv.size` is "
238- << numel << " , but `send.size` is " << send.Shape ()-> Product () << " ." ;
238+ << numel << " , but `send.size` is " << send.Shape (). Product () << " ." ;
239239 NCCL_CALL (ncclGroupStart ());
240240 uint8_t * data = static_cast <uint8_t *>(buffer->data );
241241 for (int i = 0 ; i < num_receiver; ++i) {
@@ -251,7 +251,7 @@ void GatherToWorker0(Tensor send, bool in_group, ffi::Optional<Tensor> recv) {
251251 }
252252 NCCL_CALL (ncclGroupStart ());
253253 }
254- int64_t numel = send.Shape ()-> Product ();
254+ int64_t numel = send.Shape (). Product ();
255255 DataType dtype (send->dtype );
256256 NCCL_CALL (ncclSend (send->data , numel, AsNCCLDataType (dtype), 0 ,
257257 in_group ? ctx->group_comm : ctx->global_comm , stream));
@@ -264,7 +264,7 @@ void RecvFromWorker0(Tensor buffer) {
264264 CHECK_NE (ctx->worker ->worker_id , 0 )
265265 << " ValueError: Worker 0 is not allowed to call RecvFromWorker0." ;
266266 NCCL_CALL (ncclGroupStart ());
267- NCCL_CALL (ncclRecv (buffer->data , buffer.Shape ()-> Product (), AsNCCLDataType (buffer.DataType ()), 0 ,
267+ NCCL_CALL (ncclRecv (buffer->data , buffer.Shape (). Product (), AsNCCLDataType (buffer.DataType ()), 0 ,
268268 ctx->global_comm , stream));
269269 NCCL_CALL (ncclGroupEnd ());
270270}
@@ -278,7 +278,7 @@ void SendToNextGroup(Tensor buffer) {
278278 CHECK_LT (receiver_id, ctx->worker ->num_workers )
279279 << " The current group is already the last group and there is no such a next group." ;
280280 NCCL_CALL (ncclGroupStart ());
281- NCCL_CALL (ncclSend (buffer->data , buffer.Shape ()-> Product (), AsNCCLDataType (buffer.DataType ()),
281+ NCCL_CALL (ncclSend (buffer->data , buffer.Shape (). Product (), AsNCCLDataType (buffer.DataType ()),
282282 receiver_id, ctx->global_comm , stream));
283283 NCCL_CALL (ncclGroupEnd ());
284284}
@@ -292,7 +292,7 @@ void RecvFromPrevGroup(Tensor buffer) {
292292 CHECK_GE (sender_id, 0 )
293293 << " The current group is already the first group and there is no such a previous group." ;
294294 NCCL_CALL (ncclGroupStart ());
295- NCCL_CALL (ncclRecv (buffer->data , buffer.Shape ()-> Product (), AsNCCLDataType (buffer.DataType ()),
295+ NCCL_CALL (ncclRecv (buffer->data , buffer.Shape (). Product (), AsNCCLDataType (buffer.DataType ()),
296296 sender_id, ctx->global_comm , stream));
297297 NCCL_CALL (ncclGroupEnd ());
298298}
@@ -305,7 +305,7 @@ void SendToWorker(Tensor buffer, int receiver_id) {
305305 << " Invalid receiver id " << receiver_id << " . The world size is "
306306 << ctx->worker ->num_workers ;
307307 CHECK_NE (worker_id, receiver_id) << " Cannot send to worker itself." ;
308- NCCL_CALL (ncclSend (buffer->data , buffer.Shape ()-> Product (), AsNCCLDataType (buffer.DataType ()),
308+ NCCL_CALL (ncclSend (buffer->data , buffer.Shape (). Product (), AsNCCLDataType (buffer.DataType ()),
309309 receiver_id, ctx->global_comm , stream));
310310}
311311
@@ -316,7 +316,7 @@ void RecvFromWorker(Tensor buffer, int sender_id) {
316316 CHECK (sender_id >= 0 && sender_id < ctx->worker ->num_workers )
317317 << " Invalid sender id " << sender_id << " . The world size is " << ctx->worker ->num_workers ;
318318 CHECK_NE (worker_id, sender_id) << " Cannot receive from the worker itself." ;
319- NCCL_CALL (ncclRecv (buffer->data , buffer.Shape ()-> Product (), AsNCCLDataType (buffer.DataType ()),
319+ NCCL_CALL (ncclRecv (buffer->data , buffer.Shape (). Product (), AsNCCLDataType (buffer.DataType ()),
320320 sender_id, ctx->global_comm , stream));
321321}
322322
0 commit comments