diff --git a/mooncake-transfer-engine/include/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.h b/mooncake-transfer-engine/include/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.h index c0fc6778b..329c18cfb 100644 --- a/mooncake-transfer-engine/include/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.h +++ b/mooncake-transfer-engine/include/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.h @@ -84,6 +84,16 @@ class AscendDirectTransport : public Transport { void localCopy(TransferRequest::OpCode opcode, const std::vector &slice_list); + void copyWithSync(TransferRequest::OpCode opcode, + const std::vector &slice_list, + aclrtMemcpyKind kind); + + void copyWithAsync(TransferRequest::OpCode opcode, + const std::vector &slice_list, + aclrtMemcpyKind kind); + + uint16_t findAdxlListenPort() const; + private: int InitAdxlEngine(); @@ -114,6 +124,8 @@ class AscendDirectTransport : public Transport { std::string local_adxl_engine_name_{}; aclrtStream stream_{}; bool use_buffer_pool_{false}; + + int32_t base_port_ = 20000; }; } // namespace mooncake diff --git a/mooncake-transfer-engine/src/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.cpp b/mooncake-transfer-engine/src/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.cpp index 1136e7a13..7fc8b5636 100644 --- a/mooncake-transfer-engine/src/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.cpp +++ b/mooncake-transfer-engine/src/transport/ascend_transport/ascend_direct_transport/ascend_direct_transport.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include "common.h" #include "transfer_engine.h" @@ -151,6 +152,8 @@ int AscendDirectTransport::InitAdxlEngine() { LOG(INFO) << "Set RdmaServiceLevel to:" << rdma_sl; } } + // set default buffer pool + options["adxl.BufferPool"] = "4:8"; char *buffer_pool = std::getenv("ASCEND_BUFFER_POOL"); if (buffer_pool) { options["adxl.BufferPool"] = buffer_pool; @@ -421,13 +424,11 @@ int AscendDirectTransport::allocateLocalSegmentID() { return ret; } desc->rank_info.hostIp = host_ip; - int sockfd; - desc->rank_info.hostPort = findAvailableTcpPort(sockfd); + desc->rank_info.hostPort = findAdxlListenPort(); if (desc->rank_info.hostPort == 0) { LOG(ERROR) << "Find available port failed."; return FAILED; } - close(sockfd); local_adxl_engine_name_ = host_ip + ":" + std::to_string(desc->rank_info.hostPort); @@ -439,6 +440,67 @@ int AscendDirectTransport::allocateLocalSegmentID() { return 0; } +uint16_t AscendDirectTransport::findAdxlListenPort() const { + int32_t dev_id = device_logic_id_; + char *rt_visible_devices = std::getenv("ASCEND_RT_VISIBLE_DEVICES"); + if (rt_visible_devices) { + std::vector device_list; + std::stringstream ss(rt_visible_devices); + std::string item; + while (std::getline(ss, item, ',')) { + device_list.push_back(item); + } + if (dev_id < static_cast(device_list.size())) { + try { + dev_id = std::stoi(device_list[dev_id]); + } catch (const std::exception &e) { + LOG(WARNING) << "ASCEND_RT_VISIBLE_DEVICES is not valid, value:" + << rt_visible_devices; + } + } else { + LOG(WARNING) << "Device id is " << dev_id + << ", ASCEND_RT_VISIBLE_DEVICES is " + << rt_visible_devices << ", which is unexpected."; + } + } + static std::random_device rand_gen; + std::uniform_int_distribution rand_dist; + const int min_port = base_port_ + dev_id * 1000; + const int max_port = base_port_ + (dev_id + 1) * 1000; + LOG(INFO) << "Find available between " << min_port << " and " << max_port; + const int max_attempts = 500; + int sockfd; + for (int attempt = 0; attempt < max_attempts; ++attempt) { + int port = min_port + rand_dist(rand_gen) % (max_port - min_port + 1); + sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (sockfd == -1) { + continue; + } + struct timeval timeout; + timeout.tv_sec = 1; + timeout.tv_usec = 0; + if (setsockopt(sockfd, SOL_SOCKET, SO_RCVTIMEO, &timeout, + sizeof(timeout))) { + close(sockfd); + sockfd = -1; + continue; + } + sockaddr_in bind_address; + memset(&bind_address, 0, sizeof(sockaddr_in)); + bind_address.sin_family = AF_INET; + bind_address.sin_port = htons(port); + bind_address.sin_addr.s_addr = INADDR_ANY; + if (bind(sockfd, (sockaddr *)&bind_address, sizeof(sockaddr_in)) < 0) { + close(sockfd); + sockfd = -1; + continue; + } + close(sockfd); + return port; + } + return 0; +} + void AscendDirectTransport::workerThread() { LOG(INFO) << "AscendDirectTransport worker thread started"; auto ret = aclrtSetCurrentContext(rt_context_); @@ -447,25 +509,27 @@ void AscendDirectTransport::workerThread() { return; } while (running_) { - std::unique_lock lock(queue_mutex_); - queue_cv_.wait(lock, - [this] { return !running_ || !slice_queue_.empty(); }); - if (!running_) { - break; - } - - if (!slice_queue_.empty()) { - auto slice_list = std::move(slice_queue_.front()); - slice_queue_.pop(); - lock.unlock(); - - if (slice_list.empty()) { - LOG(ERROR) - << "AscendDirectTransport: empty transfer request batch"; - continue; + std::vector slice_list; + { + std::unique_lock lock(queue_mutex_); + queue_cv_.wait( + lock, [this] { return !running_ || !slice_queue_.empty(); }); + if (!running_) { + break; } - - processSliceList(slice_list); + slice_list = std::move(slice_queue_.front()); + slice_queue_.pop(); + } + if (slice_list.empty()) { + LOG(ERROR) << "AscendDirectTransport: empty transfer request batch"; + continue; + } + std::unordered_map> seg_to_slices; + for (auto slice : slice_list) { + seg_to_slices[slice->target_id].push_back(slice); + } + for (auto &[seg_id, slices] : seg_to_slices) { + processSliceList(slices); } } LOG(INFO) << "AscendDirectTransport worker thread stopped"; @@ -543,63 +607,143 @@ void AscendDirectTransport::processSliceList( void AscendDirectTransport::localCopy(TransferRequest::OpCode opcode, const std::vector &slice_list) { - std::vector async_list; - for (auto &slice : slice_list) { - auto local_ptr = slice->source_addr; - auto remote_ptr = - reinterpret_cast(slice->ascend_direct.dest_addr); - aclrtPtrAttributes attributes; - auto ret = aclrtPointerGetAttributes(slice->source_addr, &attributes); - if (ret != ACL_ERROR_NONE) { - LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret; + aclrtMemcpyKind kind; + auto &first_slice = slice_list[0]; + auto remote_ptr = + reinterpret_cast(first_slice->ascend_direct.dest_addr); + aclrtPtrAttributes attributes; + auto ret = aclrtPointerGetAttributes(first_slice->source_addr, &attributes); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret; + for (auto &slice : slice_list) { slice->markFailed(); - continue; } - aclrtPtrAttributes dst_attributes; - ret = aclrtPointerGetAttributes(remote_ptr, &dst_attributes); - if (ret != ACL_ERROR_NONE) { - LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret; + } + aclrtPtrAttributes dst_attributes; + ret = aclrtPointerGetAttributes(remote_ptr, &dst_attributes); + if (ret != ACL_ERROR_NONE) { + LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret; + for (auto &slice : slice_list) { slice->markFailed(); - continue; } - if (attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST && - attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) { - LOG(ERROR) << "location of local addr is not supported."; + } + if (attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST && + attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) { + LOG(ERROR) << "location of local addr is not supported."; + for (auto &slice : slice_list) { slice->markFailed(); - continue; } - if (dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST && - dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) { - LOG(ERROR) << "location of remote addr is not supported."; + } + if (dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST && + dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) { + LOG(ERROR) << "location of remote addr is not supported."; + for (auto &slice : slice_list) { slice->markFailed(); - continue; } - aclrtMemcpyKind kind; + } + if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST && + dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) { + kind = ACL_MEMCPY_HOST_TO_HOST; + } else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE && + dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE) { + kind = ACL_MEMCPY_DEVICE_TO_DEVICE; + } else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) { + kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_HOST_TO_DEVICE + : ACL_MEMCPY_DEVICE_TO_HOST; + } else { + kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_DEVICE_TO_HOST + : ACL_MEMCPY_HOST_TO_DEVICE; + } + if (kind == ACL_MEMCPY_HOST_TO_HOST) { + return copyWithSync(opcode, slice_list, kind); + } + if (kind == ACL_MEMCPY_DEVICE_TO_DEVICE) { + return copyWithAsync(opcode, slice_list, kind); + } + std::vector void_remote_addrs(slice_list.size()); + std::vector void_local_addrs(slice_list.size()); + std::vector attrs(slice_list.size()); + std::vector attrsIds(slice_list.size()); + std::vector sizes(slice_list.size()); + size_t idx = 0; + for (size_t i = 0; i < slice_list.size(); i++) { + auto device_loc = aclrtMemLocation{ + static_cast(device_logic_id_), + aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_DEVICE}; + auto host_loc = aclrtMemLocation{ + 0, aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_HOST}; + if (kind == ACL_MEMCPY_DEVICE_TO_HOST) { + attrs[i] = aclrtMemcpyBatchAttr{host_loc, device_loc, {}}; + } else { + attrs[i] = aclrtMemcpyBatchAttr{device_loc, host_loc, {}}; + } + attrsIds[i] = idx++; + auto &slice = slice_list[i]; + void_local_addrs[i] = slice->source_addr; + void_remote_addrs[i] = + reinterpret_cast(slice->ascend_direct.dest_addr); + sizes[i] = slice->length; + } + size_t fail_idx; + if (opcode == TransferRequest::WRITE) { + ret = aclrtMemcpyBatch(void_remote_addrs.data(), sizes.data(), + void_local_addrs.data(), sizes.data(), + sizes.size(), attrs.data(), attrsIds.data(), + attrs.size(), &fail_idx); + } else { + ret = aclrtMemcpyBatch(void_local_addrs.data(), sizes.data(), + void_remote_addrs.data(), sizes.data(), + sizes.size(), attrs.data(), attrsIds.data(), + attrs.size(), &fail_idx); + } + if (ret == ACL_ERROR_RT_FEATURE_NOT_SUPPORT) { + return copyWithAsync(opcode, slice_list, kind); + } + if (ret == ACL_ERROR_NONE) { + VLOG(1) << "Copy with aclrtMemcpyBatch suc."; + for (auto &slice : slice_list) { + slice->markSuccess(); + } + } else { + for (auto &slice : slice_list) { + slice->markFailed(); + } + } +} + +void AscendDirectTransport::copyWithSync(TransferRequest::OpCode opcode, + const std::vector &slice_list, + aclrtMemcpyKind kind) { + for (auto &slice : slice_list) { + auto local_ptr = slice->source_addr; + auto remote_ptr = + reinterpret_cast(slice->ascend_direct.dest_addr); auto len = slice->length; - if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST && - dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) { - ret = aclrtMemcpy(remote_ptr, len, local_ptr, len, - ACL_MEMCPY_HOST_TO_HOST); - if (ret == ACL_ERROR_NONE) { - slice->markSuccess(); - } else { - LOG(ERROR) << "aclrtMemcpyAsync failed, ret:" << ret; - slice->markFailed(); - } - continue; - } else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE && - dst_attributes.location.type == - ACL_MEM_LOCATION_TYPE_DEVICE) { - kind = ACL_MEMCPY_DEVICE_TO_DEVICE; - } else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) { - kind = (opcode == TransferRequest::WRITE) - ? ACL_MEMCPY_HOST_TO_DEVICE - : ACL_MEMCPY_DEVICE_TO_HOST; + aclError ret; + if (opcode == TransferRequest::WRITE) { + ret = aclrtMemcpy(remote_ptr, len, local_ptr, len, kind); + } else { + ret = aclrtMemcpy(local_ptr, len, remote_ptr, len, kind); + } + if (ret == ACL_ERROR_NONE) { + VLOG(1) << "Copy with aclrtMemcpy suc."; + slice->markSuccess(); } else { - kind = (opcode == TransferRequest::WRITE) - ? ACL_MEMCPY_DEVICE_TO_HOST - : ACL_MEMCPY_HOST_TO_DEVICE; + LOG(ERROR) << "aclrtMemcpy failed, ret:" << ret; + slice->markFailed(); } + } +} +void AscendDirectTransport::copyWithAsync( + TransferRequest::OpCode opcode, const std::vector &slice_list, + aclrtMemcpyKind kind) { + std::vector async_list; + aclError ret; + for (auto &slice : slice_list) { + auto local_ptr = slice->source_addr; + auto remote_ptr = + reinterpret_cast(slice->ascend_direct.dest_addr); + auto len = slice->length; if (opcode == TransferRequest::WRITE) { ret = aclrtMemcpyAsync(remote_ptr, len, local_ptr, len, kind, stream_); @@ -614,17 +758,15 @@ void AscendDirectTransport::localCopy(TransferRequest::OpCode opcode, } async_list.emplace_back(slice); } - auto ret = aclrtSynchronizeStreamWithTimeout(stream_, transfer_timeout_); + ret = aclrtSynchronizeStreamWithTimeout(stream_, transfer_timeout_); if (ret == ACL_ERROR_NONE) { + VLOG(1) << "Copy with aclrtMemcpyAsync suc."; for (auto &slice : async_list) { slice->markSuccess(); } } else { - LOG(ERROR) << "Memory copy timeout."; - ret = aclrtStreamAbort(stream_); - if (ret != ACL_ERROR_NONE) { - LOG(ERROR) << "Failed to abort stream, ret:" << ret; - } + LOG(ERROR) << "Memory copy failed."; + (void)aclrtStreamAbort(stream_); for (auto &slice : async_list) { slice->markFailed(); } @@ -636,8 +778,8 @@ int AscendDirectTransport::checkAndConnect( std::lock_guard lock(connection_mutex_); auto it = connected_segments_.find(target_adxl_engine_name); if (it != connected_segments_.end()) { - LOG(INFO) << "Already connected to target adxl engine: " - << target_adxl_engine_name; + VLOG(1) << "Already connected to target adxl engine: " + << target_adxl_engine_name; return 0; } auto status =