Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -447,25 +447,27 @@ void AscendDirectTransport::workerThread() {
return;
}
while (running_) {
std::unique_lock<std::mutex> lock(queue_mutex_);
queue_cv_.wait(lock,
[this] { return !running_ || !slice_queue_.empty(); });
if (!running_) {
break;
}

if (!slice_queue_.empty()) {
auto slice_list = std::move(slice_queue_.front());
slice_queue_.pop();
lock.unlock();

if (slice_list.empty()) {
LOG(ERROR)
<< "AscendDirectTransport: empty transfer request batch";
continue;
std::vector<Slice *> slice_list;
{
std::unique_lock<std::mutex> lock(queue_mutex_);
queue_cv_.wait(
lock, [this] { return !running_ || !slice_queue_.empty(); });
if (!running_) {
break;
}

processSliceList(slice_list);
slice_list = std::move(slice_queue_.front());
slice_queue_.pop();
}
if (slice_list.empty()) {
LOG(ERROR) << "AscendDirectTransport: empty transfer request batch";
continue;
}
std::unordered_map<SegmentID, std::vector<Slice *>> seg_to_slices;
for (auto slice : slice_list) {
seg_to_slices[slice->target_id].push_back(slice);
}
for (auto &[seg_id, slices] : seg_to_slices) {
processSliceList(slices);
}
}
LOG(INFO) << "AscendDirectTransport worker thread stopped";
Expand Down Expand Up @@ -503,7 +505,14 @@ void AscendDirectTransport::processSliceList(
}
if (target_adxl_engine_name == local_adxl_engine_name_) {
VLOG(1) << "Target is local, use memory copy.";
return localCopy(slice_list[0]->opcode, slice_list);
auto start = std::chrono::steady_clock::now();
localCopy(slice_list[0]->opcode, slice_list);
uint64_t count = std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::steady_clock::now() - start)
.count();
LOG(INFO) << "Copy to local segment: " << target_adxl_engine_name
<< " cost: " << count << " us";
return;
}
int ret = checkAndConnect(target_adxl_engine_name);
if (ret != 0) {
Expand Down Expand Up @@ -543,89 +552,95 @@ void AscendDirectTransport::processSliceList(

void AscendDirectTransport::localCopy(TransferRequest::OpCode opcode,
const std::vector<Slice *> &slice_list) {
std::vector<Slice *> async_list;
for (auto &slice : slice_list) {
auto local_ptr = slice->source_addr;
auto remote_ptr =
reinterpret_cast<void *>(slice->ascend_direct.dest_addr);
aclrtPtrAttributes attributes;
auto ret = aclrtPointerGetAttributes(slice->source_addr, &attributes);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
aclrtMemcpyKind kind;
auto &first_slice = slice_list[0];
auto remote_ptr =
reinterpret_cast<void *>(first_slice->ascend_direct.dest_addr);
aclrtPtrAttributes attributes;
auto ret = aclrtPointerGetAttributes(first_slice->source_addr, &attributes);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
for (auto &slice : slice_list) {
slice->markFailed();
continue;
}
aclrtPtrAttributes dst_attributes;
ret = aclrtPointerGetAttributes(remote_ptr, &dst_attributes);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
}
aclrtPtrAttributes dst_attributes;
ret = aclrtPointerGetAttributes(remote_ptr, &dst_attributes);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "aclrtPointerGetAttributes failed, ret:" << ret;
for (auto &slice : slice_list) {
slice->markFailed();
continue;
}
if (attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
LOG(ERROR) << "location of local addr is not supported.";
}
if (attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
LOG(ERROR) << "location of local addr is not supported.";
for (auto &slice : slice_list) {
slice->markFailed();
continue;
}
if (dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
LOG(ERROR) << "location of remote addr is not supported.";
}
if (dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_HOST &&
dst_attributes.location.type != ACL_MEM_LOCATION_TYPE_DEVICE) {
LOG(ERROR) << "location of remote addr is not supported.";
for (auto &slice : slice_list) {
slice->markFailed();
continue;
}
aclrtMemcpyKind kind;
auto len = slice->length;
if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST &&
dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
ret = aclrtMemcpy(remote_ptr, len, local_ptr, len,
ACL_MEMCPY_HOST_TO_HOST);
if (ret == ACL_ERROR_NONE) {
slice->markSuccess();
} else {
LOG(ERROR) << "aclrtMemcpyAsync failed, ret:" << ret;
slice->markFailed();
}
continue;
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE &&
dst_attributes.location.type ==
ACL_MEM_LOCATION_TYPE_DEVICE) {
kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
kind = (opcode == TransferRequest::WRITE)
? ACL_MEMCPY_HOST_TO_DEVICE
: ACL_MEMCPY_DEVICE_TO_HOST;
} else {
kind = (opcode == TransferRequest::WRITE)
? ACL_MEMCPY_DEVICE_TO_HOST
: ACL_MEMCPY_HOST_TO_DEVICE;
}
if (opcode == TransferRequest::WRITE) {
ret = aclrtMemcpyAsync(remote_ptr, len, local_ptr, len, kind,
stream_);
}
if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST &&
dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
kind = ACL_MEMCPY_HOST_TO_HOST;
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE &&
dst_attributes.location.type == ACL_MEM_LOCATION_TYPE_DEVICE) {
kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
} else if (attributes.location.type == ACL_MEM_LOCATION_TYPE_HOST) {
kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_HOST_TO_DEVICE
: ACL_MEMCPY_DEVICE_TO_HOST;
} else {
kind = (opcode == TransferRequest::WRITE) ? ACL_MEMCPY_DEVICE_TO_HOST
: ACL_MEMCPY_HOST_TO_DEVICE;
}
std::vector<void *> void_remote_addrs(slice_list.size());
std::vector<void *> void_local_addrs(slice_list.size());
std::vector<aclrtMemcpyBatchAttr> attrs(slice_list.size());
std::vector<size_t> attrsIds(slice_list.size());
std::vector<size_t> sizes(slice_list.size());
size_t idx = 0;
for (size_t i = 0; i < slice_list.size(); i++) {
auto device_loc = aclrtMemLocation{
static_cast<uint32_t>(device_logic_id_),
aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_DEVICE};
auto host_loc = aclrtMemLocation{
0, aclrtMemLocationType::ACL_MEM_LOCATION_TYPE_HOST};
if (kind == ACL_MEMCPY_DEVICE_TO_HOST) {
attrs[i] = aclrtMemcpyBatchAttr{host_loc, device_loc, {}};
} else {
ret = aclrtMemcpyAsync(local_ptr, len, remote_ptr, len, kind,
stream_);
}
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "aclrtMemcpyAsync failed, ret:" << ret;
slice->markFailed();
continue;
attrs[i] = aclrtMemcpyBatchAttr{device_loc, host_loc, {}};
}
async_list.emplace_back(slice);
attrsIds[i] = idx++;
auto &slice = slice_list[i];
void_local_addrs[i] = slice->source_addr;
void_remote_addrs[i] =
reinterpret_cast<void *>(slice->ascend_direct.dest_addr);
sizes[i] = slice->length;
}
size_t fail_idx;
if (opcode == TransferRequest::WRITE) {
ret = aclrtMemcpyBatch(void_remote_addrs.data(), sizes.data(),
void_local_addrs.data(), sizes.data(),
sizes.size(), attrs.data(), attrsIds.data(),
attrs.size(), &fail_idx);
} else {
ret = aclrtMemcpyBatch(void_local_addrs.data(), sizes.data(),
void_remote_addrs.data(), sizes.data(),
sizes.size(), attrs.data(), attrsIds.data(),
attrs.size(), &fail_idx);
}
auto ret = aclrtSynchronizeStreamWithTimeout(stream_, transfer_timeout_);
if (ret == ACL_ERROR_NONE) {
for (auto &slice : async_list) {
for (auto &slice : slice_list) {
slice->markSuccess();
}
} else {
LOG(ERROR) << "Memory copy timeout.";
ret = aclrtStreamAbort(stream_);
if (ret != ACL_ERROR_NONE) {
LOG(ERROR) << "Failed to abort stream, ret:" << ret;
}
for (auto &slice : async_list) {
for (auto &slice : slice_list) {
slice->markFailed();
}
}
Expand All @@ -636,8 +651,8 @@ int AscendDirectTransport::checkAndConnect(
std::lock_guard<std::mutex> lock(connection_mutex_);
auto it = connected_segments_.find(target_adxl_engine_name);
if (it != connected_segments_.end()) {
LOG(INFO) << "Already connected to target adxl engine: "
<< target_adxl_engine_name;
VLOG(1) << "Already connected to target adxl engine: "
<< target_adxl_engine_name;
return 0;
}
auto status =
Expand Down
Loading