Skip to content

Commit 836918f

Browse files
author
youxiao
committed
Ascend direct transport: support transferring a batch to multiple destinations
1 parent b66deff commit 836918f

File tree

1 file changed

+104
-89
lines changed

1 file changed

+104
-89
lines changed

mooncake-transfer-engine/src/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.cpp

Lines changed: 104 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -447,25 +447,27 @@ void AscendDirectTransport::workerThread() {
447447
return;
448448
}
449449
while (running_) {
450-
std::unique_lock<std::mutex> lock(queue_mutex_);
451-
queue_cv_.wait(lock,
452-
[this] { return !running_ || !slice_queue_.empty(); });
453-
if (!running_) {
454-
break;
455-
}
456-
457-
if (!slice_queue_.empty()) {
458-
auto slice_list = std::move(slice_queue_.front());
459-
slice_queue_.pop();
460-
lock.unlock();
461-
462-
if (slice_list.empty()) {
463-
LOG(ERROR)
464-
<< "AscendDirectTransport: empty transfer request batch";
465-
continue;
450+
std::vector<Slice *> slice_list;
451+
{
452+
std::unique_lock<std::mutex> lock(queue_mutex_);
453+
queue_cv_.wait(
454+
lock, [this] { return !running_ || !slice_queue_.empty(); });
455+
if (!running_) {
456+
break;
466457
}
467-
468-
processSliceList(slice_list);
458+
slice_list = std::move(slice_queue_.front());
459+
slice_queue_.pop();
460+
}
461+
if (slice_list.empty()) {
462+
LOG(ERROR) << "AscendDirectTransport: empty transfer request batch";
463+
continue;
464+
}
465+
std::unordered_map<SegmentID, std::vector<Slice *>> seg_to_slices;
466+
for (auto slice : slice_list) {
467+
seg_to_slices[slice->target_id].push_back(slice);
468+
}
469+
for (auto &[seg_id, slices] : seg_to_slices) {
470+
processSliceList(slices);
469471
}
470472
}
471473
LOG(INFO) << "AscendDirectTransport worker thread stopped";
@@ -503,7 +505,14 @@ void AscendDirectTransport::processSliceList(
503505
}
504506
if (target_adxl_engine_name == local_adxl_engine_name_) {
505507
VLOG(1) << "Target is local, use memory copy.";
506-
return localCopy(slice_list[0]->opcode, slice_list);
508+
auto start = std::chrono::steady_clock::now();
509+
localCopy(slice_list[0]->opcode, slice_list);
510+
uint64_t count = std::chrono::duration_cast<std::chrono::microseconds>(
511+
std::chrono::steady_clock::now() - start)
512+
.count();
513+
LOG(INFO) << "Copy to local segment: " << target_adxl_engine_name
514+
<< " cost: " << count << " us";
515+
return;
507516
}
508517
int ret = checkAndConnect(target_adxl_engine_name);
509518
if (ret != 0) {
@@ -543,89 +552,95 @@ void AscendDirectTransport::processSliceList(
543552

544553
void AscendDirectTransport::localCopy(TransferRequest::OpCode opcode,
545554
const std::vector<Slice *> &slice_list) {
546-
std::vector<Slice *> async_list;
547-
for (auto &slice : slice_list) {
548-
auto local_ptr = slice->source_addr;
549-
auto remote_ptr =
550-
reinterpret_cast<void *>(slice->ascend_direct.dest_addr);
551-
aclrtPtrAttributes attributes;
552-
auto ret = aclrtPointerGetAttributes(slice->source_addr, &attributes);
553-
if (ret != ACL_ERROR_NONE) {
554-
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
555+
aclrtMemcpyKind kind;
556+
auto &first_slice = slice_list[0];
557+
auto remote_ptr =
558+
reinterpret_cast<void *>(first_slice->ascend_direct.dest_addr);
559+
aclrtPtrAttributes attributes;
560+
auto ret = aclrtPointerGetAttributes(first_slice->source_addr, &attributes);
561+
if (ret != ACL_ERROR_NONE) {
562+
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
563+
for (auto &slice : slice_list) {
555564
slice->markFailed();
556-
continue;
557565
}
558-
aclrtPtrAttributes dst_attributes;
559-
ret = aclrtPointerGetAttributes(remote_ptr, &dst_attributes);
560-
if (ret != ACL_ERROR_NONE) {
561-
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
566+
}
567+
aclrtPtrAttributes dst_attributes;
568+
ret = aclrtPointerGetAttributes(remote_ptr, &dst_attributes);
569+
if (ret != ACL_ERROR_NONE) {
570+
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
571+
for (auto &slice : slice_list) {
562572
slice->markFailed();
563-
continue;
564573
}
565-
if (attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
566-
attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
567-
LOG(ERROR) << "location of local addr is not supported.";
574+
}
575+
if (attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
576+
attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
577+
LOG(ERROR) << "location of local addr is not supported.";
578+
for (auto &slice : slice_list) {
568579
slice->markFailed();
569-
continue;
570580
}
571-
if (dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
572-
dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
573-
LOG(ERROR) << "location of remote addr is not supported.";
581+
}
582+
if (dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
583+
dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
584+
LOG(ERROR) << "location of remote addr is not supported.";
585+
for (auto &slice : slice_list) {
574586
slice->markFailed();
575-
continue;
576-
}
577-
aclrtMemcpyKind kind;
578-
auto len = slice->length;
579-
if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST &&
580-
dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
581-
ret = aclrtMemcpy(remote_ptr, len, local_ptr, len,
582-
ACL_MEMCPY_HOST_TO_HOST);
583-
if (ret == ACL_ERROR_NONE) {
584-
slice->markSuccess();
585-
} else {
586-
LOG(ERROR) << "aclrtMemcpyAsync failed, ret:" << ret;
587-
slice->markFailed();
588-
}
589-
continue;
590-
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE &&
591-
dst_attributes.location.type ==
592-
ACL_MEM_LOCATION_TYPE_DEVICE) {
593-
kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
594-
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
595-
kind = (opcode == TransferRequest::WRITE)
596-
? ACL_MEMCPY_HOST_TO_DEVICE
597-
: ACL_MEMCPY_DEVICE_TO_HOST;
598-
} else {
599-
kind = (opcode == TransferRequest::WRITE)
600-
? ACL_MEMCPY_DEVICE_TO_HOST
601-
: ACL_MEMCPY_HOST_TO_DEVICE;
602587
}
603-
if (opcode == TransferRequest::WRITE) {
604-
ret = aclrtMemcpyAsync(remote_ptr, len, local_ptr, len, kind,
605-
stream_);
588+
}
589+
if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST &&
590+
dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
591+
kind = ACL_MEMCPY_HOST_TO_HOST;
592+
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE &&
593+
dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE) {
594+
kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
595+
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
596+
kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_HOST_TO_DEVICE
597+
: ACL_MEMCPY_DEVICE_TO_HOST;
598+
} else {
599+
kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_DEVICE_TO_HOST
600+
: ACL_MEMCPY_HOST_TO_DEVICE;
601+
}
602+
std::vector<void *> void_remote_addrs(slice_list.size());
603+
std::vector<void *> void_local_addrs(slice_list.size());
604+
std::vector<aclrtMemcpyBatchAttr> attrs(slice_list.size());
605+
std::vector<size_t> attrsIds(slice_list.size());
606+
std::vector<size_t> sizes(slice_list.size());
607+
size_t idx = 0;
608+
for (size_t i = 0; i < slice_list.size(); i++) {
609+
auto device_loc = aclrtMemLocation{
610+
static_cast<uint32_t>(device_logic_id_),
611+
aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_DEVICE};
612+
auto host_loc = aclrtMemLocation{
613+
0, aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_HOST};
614+
if (kind == ACL_MEMCPY_DEVICE_TO_HOST) {
615+
attrs[i] = aclrtMemcpyBatchAttr{host_loc, device_loc, {}};
606616
} else {
607-
ret = aclrtMemcpyAsync(local_ptr, len, remote_ptr, len, kind,
608-
stream_);
609-
}
610-
if (ret != ACL_ERROR_NONE) {
611-
LOG(ERROR) << "aclrtMemcpyAsync failed, ret:" << ret;
612-
slice->markFailed();
613-
continue;
617+
attrs[i] = aclrtMemcpyBatchAttr{device_loc, host_loc, {}};
614618
}
615-
async_list.emplace_back(slice);
619+
attrsIds[i] = idx++;
620+
auto &slice = slice_list[i];
621+
void_local_addrs[i] = slice->source_addr;
622+
void_remote_addrs[i] =
623+
reinterpret_cast<void *>(slice->ascend_direct.dest_addr);
624+
sizes[i] = slice->length;
625+
}
626+
size_t fail_idx;
627+
if (opcode == TransferRequest::WRITE) {
628+
ret = aclrtMemcpyBatch(void_remote_addrs.data(), sizes.data(),
629+
void_local_addrs.data(), sizes.data(),
630+
sizes.size(), attrs.data(), attrsIds.data(),
631+
attrs.size(), &fail_idx);
632+
} else {
633+
ret = aclrtMemcpyBatch(void_local_addrs.data(), sizes.data(),
634+
void_remote_addrs.data(), sizes.data(),
635+
sizes.size(), attrs.data(), attrsIds.data(),
636+
attrs.size(), &fail_idx);
616637
}
617-
auto ret = aclrtSynchronizeStreamWithTimeout(stream_, transfer_timeout_);
618638
if (ret == ACL_ERROR_NONE) {
619-
for (auto &slice : async_list) {
639+
for (auto &slice : slice_list) {
620640
slice->markSuccess();
621641
}
622642
} else {
623-
LOG(ERROR) << "Memory copy timeout.";
624-
ret = aclrtStreamAbort(stream_);
625-
if (ret != ACL_ERROR_NONE) {
626-
LOG(ERROR) << "Failed to abort stream, ret:" << ret;
627-
}
628-
for (auto &slice : async_list) {
643+
for (auto &slice : slice_list) {
629644
slice->markFailed();
630645
}
631646
}
@@ -636,8 +651,8 @@ int AscendDirectTransport::checkAndConnect(
636651
std::lock_guard<std::mutex> lock(connection_mutex_);
637652
auto it = connected_segments_.find(target_adxl_engine_name);
638653
if (it != connected_segments_.end()) {
639-
LOG(INFO) << "Already connected to target adxl engine: "
640-
<< target_adxl_engine_name;
654+
VLOG(1) << "Already connected to target adxl engine: "
655+
<< target_adxl_engine_name;
641656
return 0;
642657
}
643658
auto status =

0 commit comments

Comments
 (0)