Skip to content

Commit a0b6398

Browse files
committed
impl nll_loss with aclnn for ascend
1 parent 523eede commit a0b6398

File tree

3 files changed

+107
-6
lines changed

3 files changed

+107
-6
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/**
2+
* @file
3+
* @author DeepLink
4+
* @copyright (c) 2024, DeepLink.
5+
*/
6+
7+
#include "../aclnn/acl_scalar.hpp"
8+
#include "../aclnn/adaptor.hpp"
9+
10+
namespace impl {
11+
namespace ascend {
12+
/**
 * @brief Computes the negative log likelihood loss via the Ascend aclnn kernels.
 * @param ctx          diopi context handle.
 * @param out          Output loss tensor (scalar for mean/sum reduction).
 * @param totalWeight  Output: sum of the weights of all non-ignored targets.
 * @param input        Log-probability input tensor.
 * @param target       Class-index target tensor.
 * @param weight       Optional per-class weight tensor; may be nullptr (treated as all ones).
 * @param reduction    Reduction mode (none / mean / sum).
 * @param ignoreIndex  Target value whose contribution is ignored.
 * @return diopiSuccess on success.
 */
diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t totalWeight, diopiConstTensorHandle_t input,
                            diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex) {
    if (input == nullptr) {
        return diopiSuccess;
    }

    AscendTensor inputAt(input);
    if (inputAt.numel() <= 0) {
        // Empty input: mean reduction yields NaN (0/0); sum and none reductions yield 0.
        if (diopiReduction_t::ReductionMean == reduction) {
            // Fixed op-name typo: "Inpalce" -> "Inplace" (matches aclnnInplaceOne usage below).
            DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceFillScalar, ctx, out, std::nanf(""));
        } else if (diopiReduction_t::ReductionSum == reduction || diopiReduction_t::ReductionNone == reduction) {
            DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceZero, ctx, out);
        }
        return diopiSuccess;
    }

    // aclnn requires a weight tensor; when none is supplied, build an all-ones weight of size [C].
    diopiTensorHandle_t weightTmp = const_cast<diopiTensorHandle_t>(weight);
    if (weightTmp == nullptr) {
        // Channel dim: shape(1) for dim >= 4, last dim otherwise.
        // NOTE(review): for 3-d (N, C, d1) input the channel is conventionally shape(1) as
        // well — confirm shape(-1) is intended for the dim == 3 case.
        const int64_t channel = inputAt.dim() >= 4 ? inputAt.shape(1) : inputAt.shape(-1);
        std::vector<int64_t> weightSize{channel};
        diopiSize_t weightShape = vectorToDiopiSize(weightSize);
        diopiRequireTensor(ctx, &weightTmp, &weightShape, nullptr, inputAt.dtype(), diopi_device);
        DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, weightTmp);
    }

    if (inputAt.dim() <= 2) {
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss, ctx, input, target, weightTmp, reduction, ignoreIndex, out, totalWeight);
    } else if (inputAt.dim() == 4) {
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2d, ctx, input, target, weightTmp, reduction, ignoreIndex, out, totalWeight);
    } else {
        // 3-d or >= 5-d input: collapse the trailing spatial dims into one axis of width 1
        // so the 2d kernel can be reused.
        AscendTensor outAt(out);
        AscendTensor targetAt(target);
        AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1});
        AscendTensor outView = (outAt.numel() > 1) ? outAt.view({outAt.shape(0), outAt.numel() / outAt.shape(0), 1}) : outAt;
        AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1});
        // BUG FIX: the original computed these views but never invoked the kernel, so the
        // loss was silently left uncomputed on this path. Mirrors the backward dispatch.
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2d, ctx, inputView, targetView, weightTmp, reduction, ignoreIndex, outView, totalWeight);
    }

    return diopiSuccess;
}
51+
52+
/**
 * @brief Backward pass of the negative log likelihood loss via the Ascend aclnn kernels.
 * @param ctx          diopi context handle.
 * @param gradInput    Output gradient w.r.t. the input log-probabilities.
 * @param gradOutput   Incoming gradient w.r.t. the loss.
 * @param input        Forward-pass log-probability input tensor.
 * @param target       Class-index target tensor.
 * @param weight       Optional per-class weight tensor; may be nullptr (treated as all ones).
 * @param totalWeight  Sum of weights of the non-ignored targets (see note below).
 * @param reduction    Reduction mode used in the forward pass.
 * @param ignoreIndex  Target value whose contribution is ignored.
 * @return diopiSuccess on success.
 */
diopiError_t diopiNLLLossV2Backward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiConstTensorHandle_t gradOutput,
                                    diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight,
                                    diopiConstTensorHandle_t totalWeight, diopiReduction_t reduction, int64_t ignoreIndex) {
    // BUG FIX: check the handles for nullptr BEFORE wrapping them in AscendTensor;
    // the original constructed inputAt/gradInputAt from possibly-null handles first.
    if (input == nullptr || gradInput == nullptr) {
        return diopiSuccess;
    }
    AscendTensor inputAt(input);
    AscendTensor gradInputAt(gradInput);
    if (inputAt.numel() <= 0 || gradInputAt.numel() <= 0) {
        return diopiSuccess;
    }
    /*
     * totalWeight: a tensor representing the sum of weights for each element considered in the
     * NLL loss computation. When a weight tensor is provided, it is the sum of weights for all
     * non-ignored indices in the target tensor; otherwise it is the count of all non-ignored
     * indices.
     */
    // aclnn requires a weight tensor; when none is supplied, build an all-ones weight of size [C].
    diopiTensorHandle_t weightTmp = const_cast<diopiTensorHandle_t>(weight);
    if (weightTmp == nullptr) {
        // Channel dim: shape(1) for dim >= 4, last dim otherwise (same rule as the forward pass).
        const int64_t channel = inputAt.dim() >= 4 ? inputAt.shape(1) : inputAt.shape(-1);
        std::vector<int64_t> weightSize{channel};
        diopiSize_t weightShape = vectorToDiopiSize(weightSize);
        diopiRequireTensor(ctx, &weightTmp, &weightShape, nullptr, inputAt.dtype(), diopi_device);
        DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, weightTmp);
    }

    if (inputAt.dim() <= 2) {
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLossBackward, ctx, gradOutput, input, target, weightTmp, reduction, ignoreIndex, totalWeight, gradInput);
    } else if (inputAt.dim() == 4) {
        DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2dBackward, ctx, gradOutput, input, target, weightTmp, reduction, ignoreIndex, totalWeight, gradInput);
    } else {
        // 3-d or >= 5-d input: collapse the trailing spatial dims into one axis of width 1
        // so the 2d backward kernel can be reused.
        AscendTensor gradOutputAt(gradOutput);
        AscendTensor targetAt(target);

        AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1});
        AscendTensor gradInputView =
            gradInputAt.view({gradInputAt.shape(0), gradInputAt.shape(1), gradInputAt.numel() / gradInputAt.shape(0) / gradInputAt.shape(1), 1});
        // BUG FIX: the original called view() on a default-constructed tensor and discarded the
        // return value, so gradOutputView stayed empty. Assign the view of gradOutputAt instead.
        // Also removed the unused, misspelled local `gradIputAt` that shadowed gradInputAt.
        AscendTensor gradOutputView =
            (gradOutputAt.numel() > 1) ? gradOutputAt.view({gradOutputAt.shape(0), gradOutputAt.numel() / gradOutputAt.shape(0), 1}) : gradOutputAt;
        AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1});
        DIOPI_ASCEND_CALL_ACLNN(
            aclnnNLLLoss2dBackward, ctx, gradOutputView, inputView, targetView, weightTmp, reduction, ignoreIndex, totalWeight, gradInputView);
    }
    return diopiSuccess;
}
98+
99+
} // namespace ascend
100+
} // namespace impl

impl/ascend_npu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ set(OLD_IMPL_SRC
197197
${OLD_IMPL_DIR}/functions/zeros.cpp
198198
${OLD_IMPL_DIR}/functions/matmul.cpp
199199
${OLD_IMPL_DIR}/functions/equal.cpp
200+
${OLD_IMPL_DIR}/functions/nlllossv2.cpp
200201
${OLD_IMPL_DIR}/functions_mmcv/roi_align_npu.cpp
201202
${OLD_IMPL_DIR}/functions_ext/rms_norm.cpp
202203
#${OLD_IMPL_DIR}/test/export_functions.cpp

impl/ascend_npu/ascend_config.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ ascend:
1616
- diopiArgmax
1717
- diopiAtan
1818
- diopiAtanInp
19-
- diopiAttention
20-
- diopiAttentionBackward
21-
- diopiAttentionVarLen
22-
- diopiAttentionVarLenBackward
2319
- diopiBaddbmm
2420
- diopiBaddbmmInp
2521
- diopiBitwiseNot
@@ -148,6 +144,8 @@ ascend:
148144
- diopiNeInp
149145
- diopiNeInpScalar
150146
- diopiNeScalar
147+
- diopiNLLLossV2
148+
- diopiNLLLossV2Backward
151149
- diopiNorm
152150
- diopiNormal
153151
- diopiNormalInp
@@ -214,6 +212,10 @@ ascend_npu:
214212
- diopiAdamW
215213
- diopiAdaptiveAvgPool2d
216214
- diopiAdaptiveAvgPool2dBackward
215+
- diopiAttention
216+
- diopiAttentionBackward
217+
- diopiAttentionVarLen
218+
- diopiAttentionVarLenBackward
217219
- diopiBatchNorm
218220
- diopiBatchNormBackward
219221
- diopiNonzero
@@ -248,8 +250,6 @@ ascend_npu:
248250
- diopiMm
249251
- diopiNLLLoss
250252
- diopiNLLLossBackward
251-
- diopiNLLLossV2
252-
- diopiNLLLossV2Backward
253253
- diopiScatter
254254
- diopiScatterInp
255255
- diopiScatterScalar

0 commit comments

Comments
 (0)