[Ascend] fuj/impl-nllloss-for-ascend #1237

Status: Open · wants to merge 2 commits into base: feat/replace_opplugin_by_aclnn
2 changes: 1 addition & 1 deletion impl/ascend/aclnn/adaptor.hpp
@@ -88,7 +88,7 @@ inline aclTensor* createAclTensorFromAscendTensor(const AscendTensor& input) {
         input.getAclDataType(),
         stride.data(),
         input.storageOffset(),
-        format,  // input.getAclDataFormat(), // TODO(lljbash): op_plugin assume non-channel-last, why?
+        format,
         &storageSize,
         /*storageDimsNum=*/1,
         const_cast<void*>(storagePtr));
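This hunk only removes a stale TODO comment; the format argument itself is unchanged. Judging from the argument list (dtype, strides, storage offset, format, storage size, storage dim count, storage pointer), the function presumably wraps CANN's aclCreateTensor. A minimal sketch of the presumed full call follows; every name outside the hunk above is an assumption for illustration, not code from this PR:

// Sketch only, assuming the CANN aclCreateTensor signature; viewShape,
// stride, storageSize, storagePtr, and format are reconstructed names.
aclTensor* tensor = aclCreateTensor(viewShape.data(),               // view dims
                                    viewShape.size(),               // view dim count
                                    input.getAclDataType(),         // element type
                                    stride.data(),                  // element strides
                                    input.storageOffset(),          // offset into storage, in elements
                                    format,                         // aclFormat, e.g. ACL_FORMAT_ND
                                    &storageSize,                   // storage dims, flattened
                                    /*storageDimsNum=*/1,           // storage treated as 1-D
                                    const_cast<void*>(storagePtr)); // raw device storage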
102 changes: 102 additions & 0 deletions impl/ascend/functions/nlllossv2.cpp
@@ -0,0 +1,102 @@
/**
 * @file
 * @author DeepLink
 * @copyright (c) 2024, DeepLink.
 */

#include <cmath>
#include <vector>

#include "../aclnn/acl_scalar.hpp"
#include "../aclnn/adaptor.hpp"

namespace impl {
namespace ascend {

diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t totalWeight, diopiConstTensorHandle_t input,
                            diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex) {
    if (input == nullptr) {
        return diopiSuccess;
    }

    AscendTensor inputAt(input);
    if (inputAt.numel() <= 0) {
        // Mean reduction over an empty input is 0/0, so the output is filled with NaN;
        // sum and none reductions produce zeros.
        if (diopiReduction_t::ReductionMean == reduction) {
            diopiScalar_t nans{diopi_dtype_float64, std::nan("")};  // std::nan, not std::nanf: the scalar is float64
            DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceFillScalar, ctx, out, &nans);
        } else if (diopiReduction_t::ReductionSum == reduction || diopiReduction_t::ReductionNone == reduction) {
            DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceZero, ctx, out);
        }
        return diopiSuccess;
    }

    // Without an explicit weight tensor, build a ones tensor of length C (the default class weights).
    // The class dim is the last dim for (C,) and (N, C) inputs, and dim 1 otherwise.
    diopiTensorHandle_t weightTmp = const_cast<diopiTensorHandle_t>(weight);
    if (weightTmp == nullptr) {
        const int64_t channel = inputAt.dim() <= 2 ? inputAt.shape(-1) : inputAt.shape(1);
        std::vector<int64_t> weightSize{channel};
        diopiSize_t weightShape = vectorToDiopiSize(weightSize);
        diopiRequireTensor(ctx, &weightTmp, &weightShape, nullptr, inputAt.dtype(), diopi_device);
        DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, weightTmp);
    }

    if (inputAt.dim() <= 2) {
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss, ctx, input, target, weightTmp, reduction, ignoreIndex, out, totalWeight);
    } else if (inputAt.dim() == 4) {
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2d, ctx, input, target, weightTmp, reduction, ignoreIndex, out, totalWeight);
    } else {
        // Collapse the trailing spatial dims into one so aclnnNLLLoss2d can process a (N, C, D, 1) view.
        AscendTensor outAt(out);
        AscendTensor targetAt(target);
        AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1});
        AscendTensor outView = (outAt.numel() > 1) ? outAt.view({outAt.shape(0), outAt.numel() / outAt.shape(0), 1}) : outAt;
        AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1});
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2d, ctx, inputView, targetView, weightTmp, reduction, ignoreIndex, outView, totalWeight);
    }

    return diopiSuccess;
}

diopiError_t diopiNLLLossV2Backward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiConstTensorHandle_t gradOutput,
                                    diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight,
                                    diopiConstTensorHandle_t totalWeight, diopiReduction_t reduction, int64_t ignoreIndex) {
    // Check the handles before wrapping them in AscendTensor, mirroring the forward pass.
    if (input == nullptr || gradInput == nullptr) {
        return diopiSuccess;
    }
    AscendTensor inputAt(input);
    AscendTensor gradInputAt(gradInput);
    if (inputAt.numel() <= 0 || gradInputAt.numel() <= 0) {
        return diopiSuccess;
    }
    /*
     * A tensor representing the sum of weights for each element considered in the NLL loss computation.
     * In case a weight tensor is provided, total_weight represents the sum of weights for all the non-ignored indices in the target tensor.
     * When no weight tensor is provided, total_weight corresponds to the count of all non-ignored indices.
     */
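    /*
     * Worked example (editor's illustration, not taken from the aclnn docs): with
     * weight = [w0, w1, w2] and target = [0, 2, ignoreIndex], the ignored entry is
     * skipped, so total_weight = w0 + w2; with no weight tensor, total_weight = 2,
     * the number of non-ignored targets. Under mean reduction the summed loss is
     * divided by exactly this value.
     */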
    diopiTensorHandle_t weightTmp = const_cast<diopiTensorHandle_t>(weight);
    if (weightTmp == nullptr) {
        const int64_t channel = inputAt.dim() <= 2 ? inputAt.shape(-1) : inputAt.shape(1);
        std::vector<int64_t> weightSize{channel};
        diopiSize_t weightShape = vectorToDiopiSize(weightSize);
        diopiRequireTensor(ctx, &weightTmp, &weightShape, nullptr, inputAt.dtype(), diopi_device);
        DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, weightTmp);
    }

    if (inputAt.dim() <= 2) {
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLossBackward, ctx, gradOutput, input, target, weightTmp, reduction, ignoreIndex, totalWeight, gradInput);
    } else if (inputAt.dim() == 4) {
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2dBackward, ctx, gradOutput, input, target, weightTmp, reduction, ignoreIndex, totalWeight, gradInput);
    } else {
        // Same collapsing trick as in the forward pass: view input/gradInput as (N, C, D, 1)
        // and target/gradOutput as (N, D, 1) so the 2d backward kernel applies.
        AscendTensor gradOutputAt(gradOutput);
        AscendTensor targetAt(target);

        AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1});
        AscendTensor gradInputView =
            gradInputAt.view({gradInputAt.shape(0), gradInputAt.shape(1), gradInputAt.numel() / gradInputAt.shape(0) / gradInputAt.shape(1), 1});
        AscendTensor gradOutputView;
        if (gradOutputAt.numel() > 1) {
            gradOutputView = gradOutputAt.view({gradOutputAt.shape(0), gradOutputAt.numel() / gradOutputAt.shape(0), 1});
        } else {
            gradOutputView = gradOutputAt;
        }
        AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1});
        DIOPI_ASCEND_CALL_ACLNN(
            aclnnNLLLoss2dBackward, ctx, gradOutputView, inputView, targetView, weightTmp, reduction, ignoreIndex, totalWeight, gradInputView);
    }
    return diopiSuccess;
}

} // namespace ascend
} // namespace impl
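Both fallback branches above rely on the same shape arithmetic: an input that is not 1-D, 2-D, or 4-D is viewed as (N, C, D, 1), where D is the product of all trailing dims, so the aclnnNLLLoss2d kernels can treat those dims as a single spatial dimension. A self-contained sketch of that arithmetic follows; collapseForNllLoss2d is a hypothetical helper written for illustration, not part of this PR:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Collapse (N, C, d1, ..., dk) to (N, C, d1*...*dk, 1), mirroring the
// inputView computation in diopiNLLLossV2 above.
std::vector<int64_t> collapseForNllLoss2d(const std::vector<int64_t>& shape) {
    const int64_t numel = std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
    return {shape[0], shape[1], numel / shape[0] / shape[1], 1};
}

// Example: a (2, 3, 4, 5, 6) input becomes (2, 3, 120, 1); the matching
// target view collapses the same trailing dims, giving (2, 120, 1).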
1 change: 1 addition & 0 deletions impl/ascend_npu/CMakeLists.txt
@@ -197,6 +197,7 @@ set(OLD_IMPL_SRC
     ${OLD_IMPL_DIR}/functions/matmul.cpp
     ${OLD_IMPL_DIR}/functions/max_pool2d.cpp
     ${OLD_IMPL_DIR}/functions/equal.cpp
+    ${OLD_IMPL_DIR}/functions/nlllossv2.cpp
     ${OLD_IMPL_DIR}/functions_mmcv/roi_align_npu.cpp
     ${OLD_IMPL_DIR}/functions_ext/rms_norm.cpp
     ${OLD_IMPL_DIR}/functions_ext/adamw.cpp
4 changes: 2 additions & 2 deletions impl/ascend_npu/ascend_config.yaml
@@ -162,6 +162,8 @@ ascend:
  - diopiNeInp
  - diopiNeInpScalar
  - diopiNeScalar
+ - diopiNLLLossV2
+ - diopiNLLLossV2Backward
  - diopiNorm
  - diopiNormal
  - diopiNormalInp
@@ -261,8 +263,6 @@ ascend_npu:
  - diopiMm
  - diopiNLLLoss
  - diopiNLLLossBackward
- - diopiNLLLossV2
- - diopiNLLLossV2Backward
  - diopiFlashAttention
  - diopiFlashAttentionBackward
  - diopiFlashAttentionV2