diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index 737e62f14..d8f56164e 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -947,6 +947,8 @@ def cross_entropy( weight = weight else: weight = None + + totalWeight = Tensor((1,), input.get_dtype()) sizeI = input.size().data sizeO = [sizeI[0]] + sizeI[2:] @@ -956,7 +958,7 @@ def cross_entropy( out = Tensor((), input.get_dtype()) reduction_mode = convert_reduction(reduction) - func = check_function("diopiCrossEntropyLoss") + func = check_function("diopiCrossEntropyLossWithTotalWeight") ret = func( input.context(), out, @@ -968,6 +970,7 @@ def cross_entropy( label_smoothing, ) check_returncode(ret) + GLOBAL_STATE["cross_entropy_totalWeight"] = totalWeight return out @@ -4640,9 +4643,11 @@ def cross_entropy_backward( weight = weight else: weight = None + + totalWeight = GLOBAL_STATE.pop("cross_entropy_totalWeight") reduction_mode = convert_reduction(reduction) - func = check_function("diopiCrossEntropyLossBackward") + func = check_function("diopiCrossEntropyLossWithTotalWeightBackward") ret = func( input.context(), grad_input, diff --git a/impl/ascend/convert_config.yaml b/impl/ascend/convert_config.yaml index ac320648a..895b12710 100755 --- a/impl/ascend/convert_config.yaml +++ b/impl/ascend/convert_config.yaml @@ -479,3 +479,11 @@ - diopiMaxPool2dBackward: tensor_dtype: indices: (int64)->int32 + +- diopiCrossEntropyLossWithTotalWeight: + dtype: (float64)->float32 + layout: ND + +- diopiCrossEntropyLossWithTotalWeightBackward: + dtype: (float64)->float32 + layout: ND diff --git a/impl/ascend/functions/loss.cpp b/impl/ascend/functions/loss.cpp index 365fcd7b0..53f78b339 100644 --- a/impl/ascend/functions/loss.cpp +++ b/impl/ascend/functions/loss.cpp @@ -4,295 +4,183 @@ * @copyright (c) 2023, DeepLink. */ -#include - -#include "../common/acloprunner.hpp" +#include "../aclnn/acl_scalar.hpp" +#include "../aclnn/adaptor.hpp" namespace impl { namespace ascend { -diopiError_t nllLossOutWithTotalWeight(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t totalWeight, diopiConstTensorHandle_t input, - diopiConstTensorHandle_t target, diopiTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex) { - AscendTensor inputAt0(input), outAt0(out); +diopiError_t diopiNLLLoss(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, + diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex) { + return diopiNoImplement; +} + +diopiError_t diopiNLLLossBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, + diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex) { + return diopiNoImplement; +} - if (0 == inputAt0.numel()) { - // align with pytorch +diopiError_t diopiNLLLossV2(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t totalWeight, diopiConstTensorHandle_t input, + diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex) { + if (input == nullptr) { + return diopiSuccess; + } + + AscendTensor inputAt(input); + if (inputAt.numel() <= 0) { if (diopiReduction_t::ReductionMean == reduction) { - fillNan(ctx, outAt0); + diopiScalar_t nans{diopi_dtype_float64, {std::nanf("")}}; + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceFillScalar, ctx, out, &nans); } else if (diopiReduction_t::ReductionSum == reduction || diopiReduction_t::ReductionNone == reduction) { - fillTensor(ctx, out, 0.0f); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceZero, ctx, out); } return diopiSuccess; } - diopiSize_t inputShape; - diopiTensorHandle_t inputCopy, targetCopy; - diopiGetTensorShape(input, &inputShape); - - diopiDtype_t dtype; - diopiGetTensorDtype(input, &dtype); - - int64_t batch = 1; - if (inputShape.len > 2) { - std::vector inputCopyShapeVec; - std::vector permuteDimVec; - inputCopyShapeVec.push_back(inputShape.data[0]); - permuteDimVec.push_back(0); - for (int i = 1; i < inputShape.len - 1; i++) { - inputCopyShapeVec.push_back(inputShape.data[i + 1]); - permuteDimVec.push_back(i + 1); - } - inputCopyShapeVec.push_back(inputShape.data[1]); - permuteDimVec.push_back(1); - diopiSize_t inputCopyShape = vectorToDiopiSize(inputCopyShapeVec); - diopiSize_t permuteDim = vectorToDiopiSize(permuteDimVec); - - diopiRequireTensor(ctx, &inputCopy, &inputCopyShape, nullptr, dtype, diopi_device); - diopiPermute(ctx, inputCopy, input, permuteDim); - - for (int i = 0; i < inputCopyShapeVec.size() - 1; ++i) - if (inputCopyShapeVec[i] != 0) batch *= inputCopyShapeVec[i]; - } else { - inputCopy = contiguous(ctx, input); + diopiTensorHandle_t weightTmp = const_cast(weight); + if (weightTmp == nullptr) { + const int64_t channel = inputAt.dim() >= 4 ? inputAt.shape(1) : inputAt.shape(-1); + std::vector weightSize{channel}; + diopiSize_t weightShape = vectorToDiopiSize(weightSize); + diopiRequireTensor(ctx, &weightTmp, &weightShape, nullptr, inputAt.dtype(), diopi_device); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, weightTmp); } - targetCopy = contiguous(ctx, target, diopi_dtype_int32); - AscendTensor inputAt(inputCopy), outAt(out), targetAt(targetCopy); - - AclOpRunner<3, 2> runner("NLLLoss", ctx); - AscendTensor weightAt(weight); - - AscendTensor weightAtTmp; - if (0 <= ignoreIndex && ignoreIndex < inputAt.shape(-1)) { - diopiTensorHandle_t weightTmp; - weightTmp = clone(ctx, weight); - weightAtTmp = AscendTensor(weightTmp); - castTensor(ctx, weightAtTmp, inputAt.dtype()); - diopiStreamHandle_t stream; - void* ptr = reinterpret_cast(const_cast(weightAtTmp.data())) + ignoreIndex * weightAtTmp.elemsize(); - if (inputAt.dtype() == diopi_dtype_float16) { - half_float::half val = static_cast(0); - diopiGetStream(ctx, &stream); - aclrtMemcpyAsync(ptr, sizeof(half_float::half), &val, sizeof(half_float::half), ACL_MEMCPY_HOST_TO_DEVICE, stream); - } else { - float val = 0.0f; - diopiGetStream(ctx, &stream); - aclrtMemcpyAsync(ptr, sizeof(float), &val, sizeof(float), ACL_MEMCPY_HOST_TO_DEVICE, stream); - } + if (inputAt.dim() <= 2) { + DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss, ctx, input, target, weightTmp, static_cast(reduction), ignoreIndex, out, totalWeight); + } else if (inputAt.dim() == 4) { + DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2d, ctx, input, target, weightTmp, static_cast(reduction), ignoreIndex, out, totalWeight); } else { - weightAtTmp = weightAt; - } - - // ascend only support input tensor with 2D dimension - if (inputShape.len == 1) { - reshape(ctx, inputAt, inputAt, {1, inputShape.data[0]}); - reshape(ctx, targetAt, targetAt, {targetAt.numel()}); - } else if (inputShape.len > 2) { - reshape(ctx, inputAt, inputAt, {batch, inputShape.data[1]}); - reshape(ctx, targetAt, targetAt, {targetAt.numel()}); + AscendTensor outAt(out); + AscendTensor targetAt(target); + AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1}); + AscendTensor outView = (outAt.numel() > 1) ? outAt.view({outAt.shape(0), outAt.numel() / outAt.shape(0), 1}) : outAt; + AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1}); + DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2d, ctx, inputView, targetView, weightTmp, static_cast(reduction), ignoreIndex, outView, totalWeight); } - runner.addInput(inputAt).addInput(targetAt).addInput(weightAtTmp).setAttr("ignore_index", ignoreIndex); - - diopiDtype_t outOriginDtype = outAt.dtype(); - if (outOriginDtype != diopi_dtype_float32) { - castTensor(ctx, outAt, diopi_dtype_float32); - } - if (reduction == diopiReduction_t::ReductionMean) { - runner.setAttr("reduction", std::string("mean")); - runner.addOutput(outAt); - } else if (reduction == diopiReduction_t::ReductionSum) { - runner.setAttr("reduction", std::string("sum")); - runner.addOutput(outAt); - } else if (reduction == diopiReduction_t::ReductionNone) { - runner.setAttr("reduction", std::string("none")); - runner.addOutput(outAt); - } - runner.addOutput(totalWeight); - runner.run(); - AscendTensor outOri(out); - castTensor(ctx, outAt, outOri); return diopiSuccess; } -std::string getReductionStr(const diopiReduction_t reduction) { - std::string reductionStr = "none"; - if (diopiReduction_t::ReductionMean == reduction) { - reductionStr = "mean"; - } else if (diopiReduction_t::ReductionSum == reduction) { - reductionStr = "sum"; - } else if (diopiReduction_t::ReductionEND == reduction) { - reductionStr = "end"; - } - return reductionStr; -} - -diopiError_t diopiNLLLoss(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, - diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex) { - auto totalWeightSizeVec = std::vector({1}); - auto totalWeightSize = vectorToDiopiSize(totalWeightSizeVec); - diopiTensorHandle_t totalWeight, weightCopy; +diopiError_t diopiNLLLossV2Backward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiConstTensorHandle_t gradOutput, + diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, + diopiConstTensorHandle_t totalWeight, diopiReduction_t reduction, int64_t ignoreIndex) { AscendTensor inputAt(input); - diopiRequireTensor(ctx, &totalWeight, &totalWeightSize, nullptr, inputAt.dtype(), diopi_device); - - diopiSize_t inputShape; - diopiGetTensorShape(input, &inputShape); + AscendTensor gradInputAt(gradInput); + if (input == nullptr || gradInput == nullptr || inputAt.numel() <= 0 || gradInputAt.numel() <= 0) { + return diopiSuccess; + } + /* + * A tensor representing the sum of weights for each element considered in the NLL loss computation. + * In case a weight tensor is provided, total_weight represents the sum of weights for all the non-ignored indices in the target tensor. + * When no weight tensor is provided, total_weight corresponds to the count of all non-ignored indices. + */ + diopiTensorHandle_t weightTmp = const_cast(weight); + if (weightTmp == nullptr) { + const int64_t channel = inputAt.dim() >= 4 ? inputAt.shape(1) : inputAt.shape(-1); + std::vector weightSize{channel}; + diopiSize_t weightShape = vectorToDiopiSize(weightSize); + diopiRequireTensor(ctx, &weightTmp, &weightShape, nullptr, inputAt.dtype(), diopi_device); + DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceOne, ctx, weightTmp); + } - if (weight) { - weightCopy = contiguous(ctx, weight, inputAt.dtype()); + if (inputAt.dim() <= 2) { + DIOPI_ASCEND_CALL_ACLNN( + aclnnNLLLossBackward, ctx, gradOutput, input, target, weightTmp, static_cast(reduction), ignoreIndex, totalWeight, gradInput); + } else if (inputAt.dim() == 4) { + DIOPI_ASCEND_CALL_ACLNN( + aclnnNLLLoss2dBackward, ctx, gradOutput, input, target, weightTmp, static_cast(reduction), ignoreIndex, totalWeight, gradInput); } else { - // weight shape is (C). C is number of classes - int64_t weightDim[1]; - if (inputShape.len == 1) - weightDim[0] = inputShape.data[0]; - else - weightDim[0] = inputShape.data[1]; - diopiSize_t weightShape = arrayToDiopiSize(weightDim, 1); - diopiRequireTensor(ctx, &weightCopy, &weightShape, nullptr, inputAt.dtype(), diopi_device); - fillTensor(ctx, weightCopy, static_cast(1.0)); + AscendTensor gradIputAt(gradInput); + AscendTensor gradOutputAt(gradOutput); + AscendTensor targetAt(target); + + AscendTensor inputView = inputAt.view({inputAt.shape(0), inputAt.shape(1), inputAt.numel() / inputAt.shape(0) / inputAt.shape(1), 1}); + AscendTensor gradInputView = + gradInputAt.view({gradInputAt.shape(0), gradInputAt.shape(1), gradInputAt.numel() / gradInputAt.shape(0) / gradInputAt.shape(1), 1}); + AscendTensor gradOutputView; + if (gradOutputAt.numel() > 1) { + gradOutputView = gradOutputAt.view({gradOutputAt.shape(0), gradOutputAt.numel() / gradOutputAt.shape(0), 1}); + } else { + gradOutputView = gradOutputAt; + } + AscendTensor targetView = targetAt.view({targetAt.shape(0), targetAt.numel() / targetAt.shape(0), 1}); + DIOPI_ASCEND_CALL_ACLNN(aclnnNLLLoss2dBackward, + ctx, + gradOutputView, + inputView, + targetView, + weightTmp, + static_cast(reduction), + ignoreIndex, + totalWeight, + gradInputView); } - - nllLossOutWithTotalWeight(ctx, out, totalWeight, input, target, weightCopy, reduction, ignoreIndex); return diopiSuccess; } -diopiError_t diopiNLLLossBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, - diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex) { - auto totalWeightSizeVec = std::vector({1}); - auto totalWeightSize = vectorToDiopiSize(totalWeightSizeVec); - diopiTensorHandle_t weightCopy, totalWeight, out, inputCopy, targetCopy, gradInputCopy; - AscendTensor inputAt0(input); - diopiRequireTensor(ctx, &totalWeight, &totalWeightSize, nullptr, inputAt0.dtype(), diopi_device); - makeTensorLike(ctx, &out, gradOutput); - - diopiSize_t inputShape; - diopiGetTensorShape(input, &inputShape); - - if (weight) { - weightCopy = contiguous(ctx, weight, inputAt0.dtype()); - } else { - int64_t weightDim[1]; - if (inputShape.len == 1) - weightDim[0] = inputShape.data[0]; - else - weightDim[0] = inputShape.data[1]; - diopiSize_t weightShape = arrayToDiopiSize(weightDim, 1); - diopiRequireTensor(ctx, &weightCopy, &weightShape, nullptr, inputAt0.dtype(), diopi_device); - fillTensor(ctx, weightCopy, static_cast(1.0)); - } +diopiError_t diopiCrossEntropyLoss(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, + diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex, double labelSmoothing) { + return diopiNoImplement; +} - nllLossOutWithTotalWeight(ctx, out, totalWeight, input, target, weightCopy, reduction, ignoreIndex); +diopiError_t diopiCrossEntropyLossBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiConstTensorHandle_t gradOutput, + diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, + diopiReduction_t reduction, int64_t ignoreIndex, double labelSmoothing) { + return diopiNoImplement; +} - std::vector calShapeVec; - std::vector calTargetShapeVec; +DIOPI_API diopiError_t diopiCrossEntropyLossWithTotalWeight(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t totalWeight, + diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, + diopiReduction_t reduction, int64_t ignoreIndex, double labelSmoothing) { + diopiTensorHandle_t logInput; + AscendTensor inputAt(input); + std::vector inputShape = inputAt.shape(); + diopiSize_t inputSize = vectorToDiopiSize(inputShape); + diopiRequireTensor(ctx, &logInput, &inputSize, nullptr, inputAt.dtype(), diopi_device); - diopiDtype_t dtype, gradDtype; - diopiGetTensorDtype(input, &dtype); - diopiGetTensorDtype(gradInput, &gradDtype); + DIOPI_ASCEND_CALL_ACLNN(aclnnLogSoftmax, ctx, input, 1, logInput); - if (inputShape.len > 2) { - int64_t calShape0 = inputShape.data[0]; - std::vector inputCopyShapeVec; - std::vector permuteDimVec; - inputCopyShapeVec.push_back(inputShape.data[0]); - permuteDimVec.push_back(0); - for (int i = 1; i < inputShape.len - 1; i++) { - inputCopyShapeVec.push_back(inputShape.data[i + 1]); - permuteDimVec.push_back(i + 1); - calShape0 *= inputShape.data[i + 1]; - } - calShapeVec.push_back(calShape0); - calTargetShapeVec.push_back(calShape0); - inputCopyShapeVec.push_back(inputShape.data[1]); - calShapeVec.push_back(inputShape.data[1]); - permuteDimVec.push_back(1); - diopiSize_t inputCopyShape = vectorToDiopiSize(inputCopyShapeVec); - diopiSize_t permuteDim = vectorToDiopiSize(permuteDimVec); + diopiLogSoftmax(ctx, logInput, input, 1); - diopiRequireTensor(ctx, &inputCopy, &inputCopyShape, nullptr, dtype, diopi_device); - diopiPermute(ctx, inputCopy, input, permuteDim); - diopiRequireTensor(ctx, &gradInputCopy, &inputCopyShape, nullptr, dtype, diopi_device); - } else if (inputShape.len == 2) { - inputCopy = contiguous(ctx, input); - calShapeVec.push_back(inputShape.data[0]); - calShapeVec.push_back(inputShape.data[1]); - calTargetShapeVec.push_back(inputShape.data[0]); - } else { // inpusShape.len == 1 - inputCopy = contiguous(ctx, input); - calShapeVec.push_back(1); - calShapeVec.push_back(inputShape.data[0]); - calTargetShapeVec.push_back(1); + if (labelSmoothing > 0.0) { + return diopiNoImplement; + } else { + target = hostToDevice(ctx, target); + diopiNLLLossV2(ctx, out, totalWeight, logInput, target, weight, reduction, ignoreIndex); } - void *dataPtr, *targetPtr; - targetCopy = contiguous(ctx, target, diopi_dtype_int32); - diopiGetTensorData(inputCopy, &dataPtr); - diopiGetTensorData(targetCopy, &targetPtr); - - AscendTensor inputAt(inputCopy), yGradAt(gradOutput); + return diopiSuccess; +} - AclOpRunner<5, 1> runner("NLLLossGrad", ctx); +DIOPI_API diopiError_t diopiCrossEntropyLossWithTotalWeightBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, + diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, + diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, + diopiConstTensorHandle_t totalWeight, diopiReduction_t reduction, int64_t ignoreIndex, + double labelSmoothing) { + AscendTensor inputAt(input); + AscendTensor gradInputAt(gradInput); - runner.addInput(inputAt.data(), inputAt.getAclMemBufferSize(), calShapeVec, ACL_FORMAT_ND, inputAt.dtype()); + diopiTensorHandle_t logInput; + std::vector inputShape = inputAt.shape(); + diopiSize_t inputSize = vectorToDiopiSize(inputShape); + diopiRequireTensor(ctx, &logInput, &inputSize, nullptr, inputAt.dtype(), diopi_device); - if (reduction == diopiReduction_t::ReductionMean) { - runner.setAttr("reduction", std::string("mean")); - runner.addInput(yGradAt); - } else if (reduction == diopiReduction_t::ReductionSum) { - runner.setAttr("reduction", std::string("sum")); - runner.addInput(yGradAt); - } else if (reduction == diopiReduction_t::ReductionNone) { - runner.setAttr("reduction", std::string("none")); - runner.addInput(yGradAt.data(), yGradAt.getAclMemBufferSize(), calTargetShapeVec, ACL_FORMAT_ND, yGradAt.dtype()); - } + diopiLogSoftmax(ctx, logInput, input, 1); - runner.addInput(targetPtr, getBaseBufferSize(targetCopy), calTargetShapeVec, ACL_FORMAT_ND, diopi_dtype_int32).setAttr("ignore_index", ignoreIndex); + diopiTensorHandle_t logGradInput; + std::vector gradInputShape = gradInputAt.shape(); + diopiSize_t gradInputSize = vectorToDiopiSize(gradInputShape); + diopiRequireTensor(ctx, &logGradInput, &gradInputSize, nullptr, gradInputAt.dtype(), diopi_device); - if (inputShape.len > 2) { - void* gradInputPtr; - diopiGetTensorData(gradInputCopy, &gradInputPtr); - runner.addOutput(gradInputPtr, getBaseBufferSize(gradInputCopy), calShapeVec, ACL_FORMAT_ND, gradDtype); + if (labelSmoothing > 0.0) { + return diopiNoImplement; } else { - void* gradInputPtr; - diopiGetTensorData(gradInput, &gradInputPtr); - runner.addOutput(gradInputPtr, getBaseBufferSize(gradInput), calShapeVec, ACL_FORMAT_ND, gradDtype); + target = hostToDevice(ctx, target); + diopiNLLLossV2Backward(ctx, logGradInput, gradOutput, input, target, weight, totalWeight, reduction, ignoreIndex); + diopiLogSoftmaxBackward(ctx, gradInput, logGradInput, logInput, 1); } - runner.addInput(weightCopy).addInput(totalWeight); - runner.run(); - - if (inputShape.len > 2) { - std::vector permuteDimVec; - permuteDimVec.push_back(0); - permuteDimVec.push_back(inputShape.len - 1); - for (int i = 1; i < inputShape.len - 1; i++) { - permuteDimVec.push_back(i); - } - diopiSize_t permuteDim = vectorToDiopiSize(permuteDimVec); - diopiPermute(ctx, gradInput, gradInputCopy, permuteDim); - } - return diopiSuccess; -} -diopiError_t diopiCrossEntropyLoss(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, - diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignoreIndex, double labelSmoothing) { - diopiTensorHandle_t logTensor; - makeTensorLike(ctx, &logTensor, input); - diopiLogSoftmax(ctx, logTensor, input, 1); - target = hostToDevice(ctx, target); - diopiNLLLoss(ctx, out, logTensor, target, weight, reduction, ignoreIndex); - return diopiSuccess; -} - -diopiError_t diopiCrossEntropyLossBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiConstTensorHandle_t gradOutput, - diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, - diopiReduction_t reduction, int64_t ignoreIndex, double labelSmoothing) { - diopiTensorHandle_t logTensor, gradLog; - makeTensorLike(ctx, &logTensor, input); - diopiLogSoftmax(ctx, logTensor, input, 1); - makeTensorLike(ctx, &gradLog, gradInput); - target = hostToDevice(ctx, target); - diopiNLLLossBackward(ctx, gradLog, gradOutput, input, target, weight, reduction, ignoreIndex); - diopiLogSoftmaxBackward(ctx, gradInput, gradLog, logTensor, 1); return diopiSuccess; } diff --git a/impl/ascend_npu/ascend_config.yaml b/impl/ascend_npu/ascend_config.yaml index 9dbdec336..1e009206d 100755 --- a/impl/ascend_npu/ascend_config.yaml +++ b/impl/ascend_npu/ascend_config.yaml @@ -63,6 +63,8 @@ ascend: - diopiCosInp - diopiCrossEntropyLoss - diopiCrossEntropyLossBackward +- diopiCrossEntropyLossWithTotalWeight +- diopiCrossEntropyLossWithTotalWeightBackward - diopiCustomizedFlashAttention - diopiCustomizedFlashAttentionBackward - diopiCustomizedFlashAttentionVarLen @@ -182,6 +184,8 @@ ascend: - diopiNeScalar - diopiNeg - diopiNegInp +- diopiNLLLossV2 +- diopiNLLLossV2Backward - diopiNonzero - diopiNorm - diopiNormal @@ -269,8 +273,6 @@ ascend_npu: - diopiIndexBackward - diopiNLLLoss - diopiNLLLossBackward -- diopiNLLLossV2 -- diopiNLLLossV2Backward - diopiNativeMemoryFormatCast - diopiPagedAttention - diopiRotaryEmbeddingV2 diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h index 349993e6b..d39d2faf2 100644 --- a/proto/include/diopi/functions.h +++ b/proto/include/diopi/functions.h @@ -549,6 +549,16 @@ DIOPI_API diopiError_t diopiCrossEntropyLossBackward(diopiContextHandle_t ctx, d diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, diopiReduction_t reduction, int64_t ignore_index, double label_smoothing); +DIOPI_API diopiError_t diopiCrossEntropyLossWithTotalWeight(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t total_weight, + diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, + diopiReduction_t reduction, int64_t ignore_index, double label_smoothing); + +DIOPI_API diopiError_t diopiCrossEntropyLossWithTotalWeightBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, + diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, + diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, + diopiConstTensorHandle_t total_weight, diopiReduction_t reduction, int64_t ignore_index, + double label_smoothing); + /** * @brief Measures the NLL loss between the target and input probabilities. * @param[in] ctx Context environment.