Skip to content

Commit 931e51e

Browse files
author
youxiao
committed
Add thread pool for Ascend direct transport to handle multiple destinations.
1 parent 69cf6ea commit 931e51e

File tree

5 files changed

+123
-91
lines changed

5 files changed

+123
-91
lines changed

mooncake-store/src/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ set(MOONCAKE_STORE_SOURCES
1111
utils.cpp
1212
master_metric_manager.cpp
1313
storage_backend.cpp
14-
thread_pool.cpp
1514
etcd_helper.cpp
1615
ha_helper.cpp
1716
segment.cpp

mooncake-transfer-engine/include/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <acl/acl.h>
3232
#include "transfer_metadata.h"
3333
#include "transport/transport.h"
34+
#include "thread_pool.h"
3435
#include "adxl/adxl_engine.h"
3536

3637
namespace mooncake {
@@ -114,6 +115,9 @@ class AscendDirectTransport : public Transport {
114115
std::string local_adxl_engine_name_{};
115116
aclrtStream stream_{};
116117
bool use_buffer_pool_{false};
118+
119+
ThreadPool thread_pool_;
120+
int32_t device_id_{-1};
117121
};
118122

119123
} // namespace mooncake
File renamed without changes.

mooncake-transfer-engine/src/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.cpp

Lines changed: 119 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
#include "transport/transport.h"
3434

3535
namespace mooncake {
36-
AscendDirectTransport::AscendDirectTransport() : running_(false) {}
36+
AscendDirectTransport::AscendDirectTransport()
37+
: running_(false), thread_pool_(4) {}
3738

3839
AscendDirectTransport::~AscendDirectTransport() {
3940
LOG(INFO) << "AscendDirectTransport destructor called";
@@ -420,6 +421,11 @@ int AscendDirectTransport::allocateLocalSegmentID() {
420421
LOG(ERROR) << "Call aclrtGetCurrentContext failed, ret: " << ret;
421422
return ret;
422423
}
424+
ret = aclrtGetDevice(&device_id_);
425+
if (ret) {
426+
LOG(ERROR) << "Call aclrtGetDevice failed, ret: " << ret;
427+
return ret;
428+
}
423429
desc->rank_info.hostIp = host_ip;
424430
int sockfd;
425431
desc->rank_info.hostPort = findAvailableTcpPort(sockfd);
@@ -447,25 +453,35 @@ void AscendDirectTransport::workerThread() {
447453
return;
448454
}
449455
while (running_) {
450-
std::unique_lock<std::mutex> lock(queue_mutex_);
451-
queue_cv_.wait(lock,
452-
[this] { return !running_ || !slice_queue_.empty(); });
453-
if (!running_) {
454-
break;
455-
}
456-
457-
if (!slice_queue_.empty()) {
458-
auto slice_list = std::move(slice_queue_.front());
459-
slice_queue_.pop();
460-
lock.unlock();
461-
462-
if (slice_list.empty()) {
463-
LOG(ERROR)
464-
<< "AscendDirectTransport: empty transfer request batch";
465-
continue;
456+
std::vector<Slice *> slice_list;
457+
{
458+
std::unique_lock<std::mutex> lock(queue_mutex_);
459+
queue_cv_.wait(
460+
lock, [this] { return !running_ || !slice_queue_.empty(); });
461+
if (!running_) {
462+
break;
466463
}
467-
468-
processSliceList(slice_list);
464+
slice_list = std::move(slice_queue_.front());
465+
slice_queue_.pop();
466+
}
467+
if (slice_list.empty()) {
468+
LOG(ERROR) << "AscendDirectTransport: empty transfer request batch";
469+
continue;
470+
}
471+
std::unordered_map<SegmentID, std::vector<Slice *>> seg_to_slices;
472+
for (auto slice : slice_list) {
473+
seg_to_slices[slice->target_id].push_back(slice);
474+
}
475+
for (auto &[seg_id, slices] : seg_to_slices) {
476+
thread_pool_.enqueue([this, slices] {
477+
auto ret = aclrtSetCurrentContext(rt_context_);
478+
if (ret) {
479+
LOG(ERROR)
480+
<< "Call aclrtSetCurrentContext failed, ret: " << ret;
481+
return;
482+
}
483+
processSliceList(slices);
484+
});
469485
}
470486
}
471487
LOG(INFO) << "AscendDirectTransport worker thread stopped";
@@ -503,7 +519,14 @@ void AscendDirectTransport::processSliceList(
503519
}
504520
if (target_adxl_engine_name == local_adxl_engine_name_) {
505521
VLOG(1) << "Target is local, use memory copy.";
506-
return localCopy(slice_list[0]->opcode, slice_list);
522+
auto start = std::chrono::steady_clock::now();
523+
localCopy(slice_list[0]->opcode, slice_list);
524+
uint64_t count = std::chrono::duration_cast<std::chrono::microseconds>(
525+
std::chrono::steady_clock::now() - start)
526+
.count();
527+
LOG(INFO) << "Copy to local segment: " << target_adxl_engine_name
528+
<< " cost: " << count << " us";
529+
return;
507530
}
508531
int ret = checkAndConnect(target_adxl_engine_name);
509532
if (ret != 0) {
@@ -543,89 +566,95 @@ void AscendDirectTransport::processSliceList(
543566

544567
void AscendDirectTransport::localCopy(TransferRequest::OpCode opcode,
545568
const std::vector<Slice *> &slice_list) {
546-
std::vector<Slice *> async_list;
547-
for (auto &slice : slice_list) {
548-
auto local_ptr = slice->source_addr;
549-
auto remote_ptr =
550-
reinterpret_cast<void *>(slice->ascend_direct.dest_addr);
551-
aclrtPtrAttributes attributes;
552-
auto ret = aclrtPointerGetAttributes(slice->source_addr, &attributes);
553-
if (ret != ACL_ERROR_NONE) {
554-
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
569+
aclrtMemcpyKind kind;
570+
auto &first_slice = slice_list[0];
571+
auto remote_ptr =
572+
reinterpret_cast<void *>(first_slice->ascend_direct.dest_addr);
573+
aclrtPtrAttributes attributes;
574+
auto ret = aclrtPointerGetAttributes(first_slice->source_addr, &attributes);
575+
if (ret != ACL_ERROR_NONE) {
576+
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
577+
for (auto &slice : slice_list) {
555578
slice->markFailed();
556-
continue;
557579
}
558-
aclrtPtrAttributes dst_attributes;
559-
ret = aclrtPointerGetAttributes(remote_ptr, &dst_attributes);
560-
if (ret != ACL_ERROR_NONE) {
561-
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
580+
}
581+
aclrtPtrAttributes dst_attributes;
582+
ret = aclrtPointerGetAttributes(remote_ptr, &dst_attributes);
583+
if (ret != ACL_ERROR_NONE) {
584+
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
585+
for (auto &slice : slice_list) {
562586
slice->markFailed();
563-
continue;
564587
}
565-
if (attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
566-
attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
567-
LOG(ERROR) << "location of local addr is not supported.";
588+
}
589+
if (attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
590+
attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
591+
LOG(ERROR) << "location of local addr is not supported.";
592+
for (auto &slice : slice_list) {
568593
slice->markFailed();
569-
continue;
570594
}
571-
if (dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
572-
dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
573-
LOG(ERROR) << "location of remote addr is not supported.";
595+
}
596+
if (dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
597+
dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
598+
LOG(ERROR) << "location of remote addr is not supported.";
599+
for (auto &slice : slice_list) {
574600
slice->markFailed();
575-
continue;
576-
}
577-
aclrtMemcpyKind kind;
578-
auto len = slice->length;
579-
if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST &&
580-
dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
581-
ret = aclrtMemcpy(remote_ptr, len, local_ptr, len,
582-
ACL_MEMCPY_HOST_TO_HOST);
583-
if (ret == ACL_ERROR_NONE) {
584-
slice->markSuccess();
585-
} else {
586-
LOG(ERROR) << "aclrtMemcpyAsync failed, ret:" << ret;
587-
slice->markFailed();
588-
}
589-
continue;
590-
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE &&
591-
dst_attributes.location.type ==
592-
ACL_MEM_LOCATION_TYPE_DEVICE) {
593-
kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
594-
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
595-
kind = (opcode == TransferRequest::WRITE)
596-
? ACL_MEMCPY_HOST_TO_DEVICE
597-
: ACL_MEMCPY_DEVICE_TO_HOST;
598-
} else {
599-
kind = (opcode == TransferRequest::WRITE)
600-
? ACL_MEMCPY_DEVICE_TO_HOST
601-
: ACL_MEMCPY_HOST_TO_DEVICE;
602601
}
603-
if (opcode == TransferRequest::WRITE) {
604-
ret = aclrtMemcpyAsync(remote_ptr, len, local_ptr, len, kind,
605-
stream_);
602+
}
603+
if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST &&
604+
dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
605+
kind = ACL_MEMCPY_HOST_TO_HOST;
606+
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE &&
607+
dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE) {
608+
kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
609+
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
610+
kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_HOST_TO_DEVICE
611+
: ACL_MEMCPY_DEVICE_TO_HOST;
612+
} else {
613+
kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_DEVICE_TO_HOST
614+
: ACL_MEMCPY_HOST_TO_DEVICE;
615+
}
616+
std::vector<void *> void_remote_addrs(slice_list.size());
617+
std::vector<void *> void_local_addrs(slice_list.size());
618+
std::vector<aclrtMemcpyBatchAttr> attrs(slice_list.size());
619+
std::vector<size_t> attrsIds(slice_list.size());
620+
std::vector<size_t> sizes(slice_list.size());
621+
size_t idx = 0;
622+
for (size_t i = 0; i < slice_list.size(); i++) {
623+
auto device_loc = aclrtMemLocation{
624+
static_cast<uint32_t>(device_id_),
625+
aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_DEVICE};
626+
auto host_loc = aclrtMemLocation{
627+
0, aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_HOST};
628+
if (kind == ACL_MEMCPY_DEVICE_TO_HOST) {
629+
attrs[i] = aclrtMemcpyBatchAttr{host_loc, device_loc, {}};
606630
} else {
607-
ret = aclrtMemcpyAsync(local_ptr, len, remote_ptr, len, kind,
608-
stream_);
609-
}
610-
if (ret != ACL_ERROR_NONE) {
611-
LOG(ERROR) << "aclrtMemcpyAsync failed, ret:" << ret;
612-
slice->markFailed();
613-
continue;
631+
attrs[i] = aclrtMemcpyBatchAttr{device_loc, host_loc, {}};
614632
}
615-
async_list.emplace_back(slice);
633+
attrsIds[i] = idx++;
634+
auto &slice = slice_list[i];
635+
void_local_addrs[i] = slice->source_addr;
636+
void_remote_addrs[i] =
637+
reinterpret_cast<void *>(slice->ascend_direct.dest_addr);
638+
sizes[i] = slice->length;
639+
}
640+
size_t fail_idx;
641+
if (opcode == TransferRequest::WRITE) {
642+
ret = aclrtMemcpyBatch(void_remote_addrs.data(), sizes.data(),
643+
void_local_addrs.data(), sizes.data(),
644+
sizes.size(), attrs.data(), attrsIds.data(),
645+
attrs.size(), &fail_idx);
646+
} else {
647+
ret = aclrtMemcpyBatch(void_local_addrs.data(), sizes.data(),
648+
void_remote_addrs.data(), sizes.data(),
649+
sizes.size(), attrs.data(), attrsIds.data(),
650+
attrs.size(), &fail_idx);
616651
}
617-
auto ret = aclrtSynchronizeStreamWithTimeout(stream_, transfer_timeout_);
618652
if (ret == ACL_ERROR_NONE) {
619-
for (auto &slice : async_list) {
653+
for (auto &slice : slice_list) {
620654
slice->markSuccess();
621655
}
622656
} else {
623-
LOG(ERROR) << "Memory copy timeout.";
624-
ret = aclrtStreamAbort(stream_);
625-
if (ret != ACL_ERROR_NONE) {
626-
LOG(ERROR) << "Failed to abort stream, ret:" << ret;
627-
}
628-
for (auto &slice : async_list) {
657+
for (auto &slice : slice_list) {
629658
slice->markFailed();
630659
}
631660
}
@@ -636,8 +665,8 @@ int AscendDirectTransport::checkAndConnect(
636665
std::lock_guard<std::mutex> lock(connection_mutex_);
637666
auto it = connected_segments_.find(target_adxl_engine_name);
638667
if (it != connected_segments_.end()) {
639-
LOG(INFO) << "Already connected to target adxl engine: "
640-
<< target_adxl_engine_name;
668+
VLOG(1) << "Already connected to target adxl engine: "
669+
<< target_adxl_engine_name;
641670
return 0;
642671
}
643672
auto status =

0 commit comments

Comments
 (0)