3333#include " transport/transport.h"
3434
3535namespace mooncake {
36- AscendDirectTransport::AscendDirectTransport () : running_(false ) {}
36+ AscendDirectTransport::AscendDirectTransport ()
37+ : running_(false ), thread_pool_(4 ) {}
3738
3839AscendDirectTransport::~AscendDirectTransport () {
3940 LOG (INFO) << " AscendDirectTransport destructor called" ;
@@ -420,6 +421,11 @@ int AscendDirectTransport::allocateLocalSegmentID() {
420421 LOG (ERROR) << " Call aclrtGetCurrentContext failed, ret: " << ret;
421422 return ret;
422423 }
424+ ret = aclrtGetDevice (&device_id_);
425+ if (ret) {
426+ LOG (ERROR) << " Call aclrtGetDevice failed, ret: " << ret;
427+ return ret;
428+ }
423429 desc->rank_info .hostIp = host_ip;
424430 int sockfd;
425431 desc->rank_info .hostPort = findAvailableTcpPort (sockfd);
@@ -447,25 +453,35 @@ void AscendDirectTransport::workerThread() {
447453 return ;
448454 }
449455 while (running_) {
450- std::unique_lock<std::mutex> lock (queue_mutex_);
451- queue_cv_.wait (lock,
452- [this ] { return !running_ || !slice_queue_.empty (); });
453- if (!running_) {
454- break ;
455- }
456-
457- if (!slice_queue_.empty ()) {
458- auto slice_list = std::move (slice_queue_.front ());
459- slice_queue_.pop ();
460- lock.unlock ();
461-
462- if (slice_list.empty ()) {
463- LOG (ERROR)
464- << " AscendDirectTransport: empty transfer request batch" ;
465- continue ;
456+ std::vector<Slice *> slice_list;
457+ {
458+ std::unique_lock<std::mutex> lock (queue_mutex_);
459+ queue_cv_.wait (
460+ lock, [this ] { return !running_ || !slice_queue_.empty (); });
461+ if (!running_) {
462+ break ;
466463 }
467-
468- processSliceList (slice_list);
464+ slice_list = std::move (slice_queue_.front ());
465+ slice_queue_.pop ();
466+ }
467+ if (slice_list.empty ()) {
468+ LOG (ERROR) << " AscendDirectTransport: empty transfer request batch" ;
469+ continue ;
470+ }
471+ std::unordered_map<SegmentID, std::vector<Slice *>> seg_to_slices;
472+ for (auto slice : slice_list) {
473+ seg_to_slices[slice->target_id ].push_back (slice);
474+ }
475+ for (auto &[seg_id, slices] : seg_to_slices) {
476+ thread_pool_.enqueue ([this , slices] {
477+ auto ret = aclrtSetCurrentContext (rt_context_);
478+ if (ret) {
479+ LOG (ERROR)
480+ << " Call aclrtSetCurrentContext failed, ret: " << ret;
481+ return ;
482+ }
483+ processSliceList (slices);
484+ });
469485 }
470486 }
471487 LOG (INFO) << " AscendDirectTransport worker thread stopped" ;
@@ -503,7 +519,14 @@ void AscendDirectTransport::processSliceList(
503519 }
504520 if (target_adxl_engine_name == local_adxl_engine_name_) {
505521 VLOG (1 ) << " Target is local, use memory copy." ;
506- return localCopy (slice_list[0 ]->opcode , slice_list);
522+ auto start = std::chrono::steady_clock::now ();
523+ localCopy (slice_list[0 ]->opcode , slice_list);
524+ uint64_t count = std::chrono::duration_cast<std::chrono::microseconds>(
525+ std::chrono::steady_clock::now () - start)
526+ .count ();
527+ LOG (INFO) << " Copy to local segment: " << target_adxl_engine_name
528+ << " cost: " << count << " us" ;
529+ return ;
507530 }
508531 int ret = checkAndConnect (target_adxl_engine_name);
509532 if (ret != 0 ) {
@@ -543,89 +566,95 @@ void AscendDirectTransport::processSliceList(
543566
544567void AscendDirectTransport::localCopy (TransferRequest::OpCode opcode,
545568 const std::vector<Slice *> &slice_list) {
546- std::vector<Slice *> async_list ;
547- for ( auto &slice : slice_list) {
548- auto local_ptr = slice-> source_addr ;
549- auto remote_ptr =
550- reinterpret_cast < void *>(slice-> ascend_direct . dest_addr ) ;
551- aclrtPtrAttributes attributes;
552- auto ret = aclrtPointerGetAttributes (slice-> source_addr , &attributes);
553- if (ret != ACL_ERROR_NONE) {
554- LOG (ERROR) << " aclrtPointerGetAttributes failed, ret: " << ret;
569+ aclrtMemcpyKind kind ;
570+ auto &first_slice = slice_list[ 0 ];
571+ auto remote_ptr =
572+ reinterpret_cast < void *>(first_slice-> ascend_direct . dest_addr );
573+ aclrtPtrAttributes attributes ;
574+ auto ret = aclrtPointerGetAttributes (first_slice-> source_addr , & attributes) ;
575+ if ( ret != ACL_ERROR_NONE) {
576+ LOG (ERROR) << " aclrtPointerGetAttributes failed, ret: " << ret;
577+ for ( auto &slice : slice_list) {
555578 slice->markFailed ();
556- continue ;
557579 }
558- aclrtPtrAttributes dst_attributes;
559- ret = aclrtPointerGetAttributes (remote_ptr, &dst_attributes);
560- if (ret != ACL_ERROR_NONE) {
561- LOG (ERROR) << " aclrtPointerGetAttributes failed, ret:" << ret;
580+ }
581+ aclrtPtrAttributes dst_attributes;
582+ ret = aclrtPointerGetAttributes (remote_ptr, &dst_attributes);
583+ if (ret != ACL_ERROR_NONE) {
584+ LOG (ERROR) << " aclrtPointerGetAttributes failed, ret:" << ret;
585+ for (auto &slice : slice_list) {
562586 slice->markFailed ();
563- continue ;
564587 }
565- if (attributes.location .type != ACL_MEM_LOCATION_TYPE_HOST &&
566- attributes.location .type != ACL_MEM_LOCATION_TYPE_DEVICE) {
567- LOG (ERROR) << " location of local addr is not supported." ;
588+ }
589+ if (attributes.location .type != ACL_MEM_LOCATION_TYPE_HOST &&
590+ attributes.location .type != ACL_MEM_LOCATION_TYPE_DEVICE) {
591+ LOG (ERROR) << " location of local addr is not supported." ;
592+ for (auto &slice : slice_list) {
568593 slice->markFailed ();
569- continue ;
570594 }
571- if (dst_attributes.location .type != ACL_MEM_LOCATION_TYPE_HOST &&
572- dst_attributes.location .type != ACL_MEM_LOCATION_TYPE_DEVICE) {
573- LOG (ERROR) << " location of remote addr is not supported." ;
595+ }
596+ if (dst_attributes.location .type != ACL_MEM_LOCATION_TYPE_HOST &&
597+ dst_attributes.location .type != ACL_MEM_LOCATION_TYPE_DEVICE) {
598+ LOG (ERROR) << " location of remote addr is not supported." ;
599+ for (auto &slice : slice_list) {
574600 slice->markFailed ();
575- continue ;
576- }
577- aclrtMemcpyKind kind;
578- auto len = slice->length ;
579- if (attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST &&
580- dst_attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST) {
581- ret = aclrtMemcpy (remote_ptr, len, local_ptr, len,
582- ACL_MEMCPY_HOST_TO_HOST);
583- if (ret == ACL_ERROR_NONE) {
584- slice->markSuccess ();
585- } else {
586- LOG (ERROR) << " aclrtMemcpyAsync failed, ret:" << ret;
587- slice->markFailed ();
588- }
589- continue ;
590- } else if (attributes.location .type == ACL_MEM_LOCATION_TYPE_DEVICE &&
591- dst_attributes.location .type ==
592- ACL_MEM_LOCATION_TYPE_DEVICE) {
593- kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
594- } else if (attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST) {
595- kind = (opcode == TransferRequest::WRITE)
596- ? ACL_MEMCPY_HOST_TO_DEVICE
597- : ACL_MEMCPY_DEVICE_TO_HOST;
598- } else {
599- kind = (opcode == TransferRequest::WRITE)
600- ? ACL_MEMCPY_DEVICE_TO_HOST
601- : ACL_MEMCPY_HOST_TO_DEVICE;
602601 }
603- if (opcode == TransferRequest::WRITE) {
604- ret = aclrtMemcpyAsync (remote_ptr, len, local_ptr, len, kind,
605- stream_);
602+ }
603+ if (attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST &&
604+ dst_attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST) {
605+ kind = ACL_MEMCPY_HOST_TO_HOST;
606+ } else if (attributes.location .type == ACL_MEM_LOCATION_TYPE_DEVICE &&
607+ dst_attributes.location .type == ACL_MEM_LOCATION_TYPE_DEVICE) {
608+ kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
609+ } else if (attributes.location .type == ACL_MEM_LOCATION_TYPE_HOST) {
610+ kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_HOST_TO_DEVICE
611+ : ACL_MEMCPY_DEVICE_TO_HOST;
612+ } else {
613+ kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_DEVICE_TO_HOST
614+ : ACL_MEMCPY_HOST_TO_DEVICE;
615+ }
616+ std::vector<void *> void_remote_addrs (slice_list.size ());
617+ std::vector<void *> void_local_addrs (slice_list.size ());
618+ std::vector<aclrtMemcpyBatchAttr> attrs (slice_list.size ());
619+ std::vector<size_t > attrsIds (slice_list.size ());
620+ std::vector<size_t > sizes (slice_list.size ());
621+ size_t idx = 0 ;
622+ for (size_t i = 0 ; i < slice_list.size (); i++) {
623+ auto device_loc = aclrtMemLocation{
624+ static_cast <uint32_t >(device_id_),
625+ aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_DEVICE};
626+ auto host_loc = aclrtMemLocation{
627+ 0 , aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_HOST};
628+ if (kind == ACL_MEMCPY_DEVICE_TO_HOST) {
629+ attrs[i] = aclrtMemcpyBatchAttr{host_loc, device_loc, {}};
606630 } else {
607- ret = aclrtMemcpyAsync (local_ptr, len, remote_ptr, len, kind,
608- stream_);
609- }
610- if (ret != ACL_ERROR_NONE) {
611- LOG (ERROR) << " aclrtMemcpyAsync failed, ret:" << ret;
612- slice->markFailed ();
613- continue ;
631+ attrs[i] = aclrtMemcpyBatchAttr{device_loc, host_loc, {}};
614632 }
615- async_list.emplace_back (slice);
633+ attrsIds[i] = idx++;
634+ auto &slice = slice_list[i];
635+ void_local_addrs[i] = slice->source_addr ;
636+ void_remote_addrs[i] =
637+ reinterpret_cast <void *>(slice->ascend_direct .dest_addr );
638+ sizes[i] = slice->length ;
639+ }
640+ size_t fail_idx;
641+ if (opcode == TransferRequest::WRITE) {
642+ ret = aclrtMemcpyBatch (void_remote_addrs.data (), sizes.data (),
643+ void_local_addrs.data (), sizes.data (),
644+ sizes.size (), attrs.data (), attrsIds.data (),
645+ attrs.size (), &fail_idx);
646+ } else {
647+ ret = aclrtMemcpyBatch (void_local_addrs.data (), sizes.data (),
648+ void_remote_addrs.data (), sizes.data (),
649+ sizes.size (), attrs.data (), attrsIds.data (),
650+ attrs.size (), &fail_idx);
616651 }
617- auto ret = aclrtSynchronizeStreamWithTimeout (stream_, transfer_timeout_);
618652 if (ret == ACL_ERROR_NONE) {
619- for (auto &slice : async_list ) {
653+ for (auto &slice : slice_list ) {
620654 slice->markSuccess ();
621655 }
622656 } else {
623- LOG (ERROR) << " Memory copy timeout." ;
624- ret = aclrtStreamAbort (stream_);
625- if (ret != ACL_ERROR_NONE) {
626- LOG (ERROR) << " Failed to abort stream, ret:" << ret;
627- }
628- for (auto &slice : async_list) {
657+ for (auto &slice : slice_list) {
629658 slice->markFailed ();
630659 }
631660 }
@@ -636,8 +665,8 @@ int AscendDirectTransport::checkAndConnect(
636665 std::lock_guard<std::mutex> lock (connection_mutex_);
637666 auto it = connected_segments_.find (target_adxl_engine_name);
638667 if (it != connected_segments_.end ()) {
639- LOG (INFO ) << " Already connected to target adxl engine: "
640- << target_adxl_engine_name;
668+ VLOG ( 1 ) << " Already connected to target adxl engine: "
669+ << target_adxl_engine_name;
641670 return 0 ;
642671 }
643672 auto status =
0 commit comments