@@ -447,25 +447,27 @@ void AscendDirectTransport::workerThread() {
447
447
return ;
448
448
}
449
449
while (running_) {
450
- std::unique_lock<std::mutex> lock (queue_mutex_);
451
- queue_cv_.wait (lock,
452
- [this ] { return !running_ || !slice_queue_.empty (); });
453
- if (!running_) {
454
- break ;
455
- }
456
-
457
- if (!slice_queue_.empty ()) {
458
- auto slice_list = std::move (slice_queue_.front ());
459
- slice_queue_.pop ();
460
- lock.unlock ();
461
-
462
- if (slice_list.empty ()) {
463
- LOG (ERROR)
464
- << " AscendDirectTransport: empty transfer request batch" ;
465
- continue ;
450
+ std::vector<Slice *> slice_list;
451
+ {
452
+ std::unique_lock<std::mutex> lock (queue_mutex_);
453
+ queue_cv_.wait (
454
+ lock, [this ] { return !running_ || !slice_queue_.empty (); });
455
+ if (!running_) {
456
+ break ;
466
457
}
467
-
468
- processSliceList (slice_list);
458
+ slice_list = std::move (slice_queue_.front ());
459
+ slice_queue_.pop ();
460
+ }
461
+ if (slice_list.empty ()) {
462
+ LOG (ERROR) << " AscendDirectTransport: empty transfer request batch" ;
463
+ continue ;
464
+ }
465
+ std::unordered_map<SegmentID, std::vector<Slice *>> seg_to_slices;
466
+ for (auto slice : slice_list) {
467
+ seg_to_slices[slice->target_id ].push_back (slice);
468
+ }
469
+ for (auto &[seg_id, slices] : seg_to_slices) {
470
+ processSliceList (slices);
469
471
}
470
472
}
471
473
LOG (INFO) << " AscendDirectTransport worker thread stopped" ;
@@ -503,7 +505,14 @@ void AscendDirectTransport::processSliceList(
503
505
}
504
506
if (target_adxl_engine_name == local_adxl_engine_name_) {
505
507
VLOG (1 ) << " Target is local, use memory copy." ;
506
- return localCopy (slice_list[0 ]->opcode , slice_list);
508
+ auto start = std::chrono::steady_clock::now ();
509
+ localCopy (slice_list[0 ]->opcode , slice_list);
510
+ uint64_t count = std::chrono::duration_cast<std::chrono::microseconds>(
511
+ std::chrono::steady_clock::now () - start)
512
+ .count ();
513
+ LOG (INFO) << " Copy to local segment: " << target_adxl_engine_name
514
+ << " cost: " << count << " us" ;
515
+ return ;
507
516
}
508
517
int ret = checkAndConnect (target_adxl_engine_name);
509
518
if (ret != 0 ) {
@@ -543,89 +552,95 @@ void AscendDirectTransport::processSliceList(
543
552
544
553
void AscendDirectTransport::localCopy (TransferRequest::OpCode opcode,
545
554
const std::vector<Slice *> &slice_list) {
546
- std::vector<Slice *> async_list ;
547
- for ( auto &slice : slice_list) {
548
- auto local_ptr = slice-> source_addr ;
549
- auto remote_ptr =
550
- reinterpret_cast < void *>(slice-> ascend_direct . dest_addr ) ;
551
- aclrtPtrAttributes attributes;
552
- auto ret = aclrtPointerGetAttributes (slice-> source_addr , &attributes);
553
- if (ret != ACL_ERROR_NONE) {
554
- LOG (ERROR) << " aclrtPointerGetAttributes failed, ret: " << ret;
555
+ aclrtMemcpyKind kind ;
556
+ auto &first_slice = slice_list[ 0 ];
557
+ auto remote_ptr =
558
+ reinterpret_cast < void *>(first_slice-> ascend_direct . dest_addr );
559
+ aclrtPtrAttributes attributes ;
560
+ auto ret = aclrtPointerGetAttributes (first_slice-> source_addr , & attributes) ;
561
+ if ( ret != ACL_ERROR_NONE) {
562
+ LOG (ERROR) << " aclrtPointerGetAttributes failed, ret: " << ret;
563
+ for ( auto &slice : slice_list) {
555
564
slice->markFailed ();
556
- continue ;
557
565
}
558
- aclrtPtrAttributes dst_attributes;
559
- ret = aclrtPointerGetAttributes (remote_ptr, &dst_attributes);
560
- if (ret != ACL_ERROR_NONE) {
561
- LOG (ERROR) << " aclrtPointerGetAttributes failed, ret:" << ret;
566
+ }
567
+ aclrtPtrAttributes dst_attributes;
568
+ ret = aclrtPointerGetAttributes (remote_ptr, &dst_attributes);
569
+ if (ret != ACL_ERROR_NONE) {
570
+ LOG (ERROR) << " aclrtPointerGetAttributes failed, ret:" << ret;
571
+ for (auto &slice : slice_list) {
562
572
slice->markFailed ();
563
- continue ;
564
573
}
565
- if (attributes.location .type != ACL_MEM_LOCATION_TYPE_HOST &&
566
- attributes.location .type != ACL_MEM_LOCATION_TYPE_DEVICE) {
567
- LOG (ERROR) << " location of local addr is not supported." ;
574
+ }
575
+ if (attributes.location .type != ACL_MEM_LOCATION_TYPE_HOST &&
576
+ attributes.location .type != ACL_MEM_LOCATION_TYPE_DEVICE) {
577
+ LOG (ERROR) << " location of local addr is not supported." ;
578
+ for (auto &slice : slice_list) {
568
579
slice->markFailed ();
569
- continue ;
570
580
}
571
- if (dst_attributes.location .type != ACL_MEM_LOCATION_TYPE_HOST &&
572
- dst_attributes.location .type != ACL_MEM_LOCATION_TYPE_DEVICE) {
573
- LOG (ERROR) << " location of remote addr is not supported." ;
581
+ }
582
+ if (dst_attributes.location .type != ACL_MEM_LOCATION_TYPE_HOST &&
583
+ dst_attributes.location .type != ACL_MEM_LOCATION_TYPE_DEVICE) {
584
+ LOG (ERROR) << " location of remote addr is not supported." ;
585
+ for (auto &slice : slice_list) {
574
586
slice->markFailed ();
575
- continue ;
576
- }
577
- aclrtMemcpyKind kind;
578
- auto len = slice->length ;
579
- if (attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST &&
580
- dst_attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST) {
581
- ret = aclrtMemcpy (remote_ptr, len, local_ptr, len,
582
- ACL_MEMCPY_HOST_TO_HOST);
583
- if (ret == ACL_ERROR_NONE) {
584
- slice->markSuccess ();
585
- } else {
586
- LOG (ERROR) << " aclrtMemcpyAsync failed, ret:" << ret;
587
- slice->markFailed ();
588
- }
589
- continue ;
590
- } else if (attributes.location .type == ACL_MEM_LOCATION_TYPE_DEVICE &&
591
- dst_attributes.location .type ==
592
- ACL_MEM_LOCATION_TYPE_DEVICE) {
593
- kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
594
- } else if (attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST) {
595
- kind = (opcode == TransferRequest::WRITE)
596
- ? ACL_MEMCPY_HOST_TO_DEVICE
597
- : ACL_MEMCPY_DEVICE_TO_HOST;
598
- } else {
599
- kind = (opcode == TransferRequest::WRITE)
600
- ? ACL_MEMCPY_DEVICE_TO_HOST
601
- : ACL_MEMCPY_HOST_TO_DEVICE;
602
587
}
603
- if (opcode == TransferRequest::WRITE) {
604
- ret = aclrtMemcpyAsync (remote_ptr, len, local_ptr, len, kind,
605
- stream_);
588
+ }
589
+ if (attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST &&
590
+ dst_attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST) {
591
+ kind = ACL_MEMCPY_HOST_TO_HOST;
592
+ } else if (attributes.location .type == ACL_MEM_LOCATION_TYPE_DEVICE &&
593
+ dst_attributes.location .type == ACL_MEM_LOCATION_TYPE_DEVICE) {
594
+ kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
595
+ } else if (attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST) {
596
+ kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_HOST_TO_DEVICE
597
+ : ACL_MEMCPY_DEVICE_TO_HOST;
598
+ } else {
599
+ kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_DEVICE_TO_HOST
600
+ : ACL_MEMCPY_HOST_TO_DEVICE;
601
+ }
602
+ std::vector<void *> void_remote_addrs (slice_list.size ());
603
+ std::vector<void *> void_local_addrs (slice_list.size ());
604
+ std::vector<aclrtMemcpyBatchAttr> attrs (slice_list.size ());
605
+ std::vector<size_t > attrsIds (slice_list.size ());
606
+ std::vector<size_t > sizes (slice_list.size ());
607
+ size_t idx = 0 ;
608
+ for (size_t i = 0 ; i < slice_list.size (); i++) {
609
+ auto device_loc = aclrtMemLocation{
610
+ static_cast <uint32_t >(device_logic_id_),
611
+ aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_DEVICE};
612
+ auto host_loc = aclrtMemLocation{
613
+ 0 , aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_HOST};
614
+ if (kind == ACL_MEMCPY_DEVICE_TO_HOST) {
615
+ attrs[i] = aclrtMemcpyBatchAttr{host_loc, device_loc, {}};
606
616
} else {
607
- ret = aclrtMemcpyAsync (local_ptr, len, remote_ptr, len, kind,
608
- stream_);
609
- }
610
- if (ret != ACL_ERROR_NONE) {
611
- LOG (ERROR) << " aclrtMemcpyAsync failed, ret:" << ret;
612
- slice->markFailed ();
613
- continue ;
617
+ attrs[i] = aclrtMemcpyBatchAttr{device_loc, host_loc, {}};
614
618
}
615
- async_list.emplace_back (slice);
619
+ attrsIds[i] = idx++;
620
+ auto &slice = slice_list[i];
621
+ void_local_addrs[i] = slice->source_addr ;
622
+ void_remote_addrs[i] =
623
+ reinterpret_cast <void *>(slice->ascend_direct .dest_addr );
624
+ sizes[i] = slice->length ;
625
+ }
626
+ size_t fail_idx;
627
+ if (opcode == TransferRequest::WRITE) {
628
+ ret = aclrtMemcpyBatch (void_remote_addrs.data (), sizes.data (),
629
+ void_local_addrs.data (), sizes.data (),
630
+ sizes.size (), attrs.data (), attrsIds.data (),
631
+ attrs.size (), &fail_idx);
632
+ } else {
633
+ ret = aclrtMemcpyBatch (void_local_addrs.data (), sizes.data (),
634
+ void_remote_addrs.data (), sizes.data (),
635
+ sizes.size (), attrs.data (), attrsIds.data (),
636
+ attrs.size (), &fail_idx);
616
637
}
617
- auto ret = aclrtSynchronizeStreamWithTimeout (stream_, transfer_timeout_);
618
638
if (ret == ACL_ERROR_NONE) {
619
- for (auto &slice : async_list ) {
639
+ for (auto &slice : slice_list ) {
620
640
slice->markSuccess ();
621
641
}
622
642
} else {
623
- LOG (ERROR) << " Memory copy timeout." ;
624
- ret = aclrtStreamAbort (stream_);
625
- if (ret != ACL_ERROR_NONE) {
626
- LOG (ERROR) << " Failed to abort stream, ret:" << ret;
627
- }
628
- for (auto &slice : async_list) {
643
+ for (auto &slice : slice_list) {
629
644
slice->markFailed ();
630
645
}
631
646
}
@@ -636,8 +651,8 @@ int AscendDirectTransport::checkAndConnect(
636
651
std::lock_guard<std::mutex> lock (connection_mutex_);
637
652
auto it = connected_segments_.find (target_adxl_engine_name);
638
653
if (it != connected_segments_.end ()) {
639
- LOG (INFO ) << " Already connected to target adxl engine: "
640
- << target_adxl_engine_name;
654
+ VLOG ( 1 ) << " Already connected to target adxl engine: "
655
+ << target_adxl_engine_name;
641
656
return 0 ;
642
657
}
643
658
auto status =
0 commit comments