Skip to content

Commit 794b873

Browse files
authored
[Ascend] fuj/replace-baddbmm-and-fix-max-and-min-config (#1208)
* replace baddbmm and fix max and min config * replace bmm with aclnn * support bfloat16
1 parent a45c104 commit 794b873

File tree

5 files changed

+41
-73
lines changed

5 files changed

+41
-73
lines changed

impl/ascend/ascend_tensor.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ constexpr aclDataType diopiDtypeToAclDataType(diopiDtype_t dtype) noexcept {
9595
return acl_dtype;
9696

9797
switch (dtype) {
98+
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_bfloat16, ACL_BF16)
9899
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_float16, ACL_FLOAT16)
99100
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_float32, ACL_FLOAT)
100101
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_float64, ACL_DOUBLE)
@@ -107,6 +108,7 @@ constexpr aclDataType diopiDtypeToAclDataType(diopiDtype_t dtype) noexcept {
107108
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_int64, ACL_INT64)
108109
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_uint64, ACL_UINT64)
109110
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_bool, ACL_BOOL)
111+
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_complex32, ACL_COMPLEX32)
110112
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_complex64, ACL_COMPLEX64)
111113
DIOPI_DTYPE_TO_ACL_DTYPE_CASE(diopi_dtype_complex128, ACL_COMPLEX128)
112114
default:

impl/ascend/device_configs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@
3535
]
3636
)
3737
),
38+
# Bad in-place call: input tensor size [2] and output tensor size [2, 0, 2] should match
39+
# pytorch 2.1.0 does not support this case
40+
# input: (2,), batch1: (2, 0, 4), batch2: (2, 4, 2)
41+
'baddbmm_without_inplace': dict(
42+
name=["baddbmm"],
43+
tensor_para=dict(
44+
args=[
45+
{
46+
"ins": ["input"],
47+
"shape": [Skip((2,))],
48+
},
49+
],
50+
),
51+
),
3852

3953
# temp for 910B
4054
'uniform': dict(

impl/ascend/functions/baddbmm.cpp

Lines changed: 14 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,72 +4,32 @@
44
* @copyright (c) 2023, DeepLink.
55
*/
66

7-
#include "../common/acloprunner.hpp"
7+
#include "../aclnn/acl_scalar.hpp"
8+
#include "../aclnn/adaptor.hpp"
89

910
namespace impl {
1011
namespace ascend {
1112

1213
diopiError_t diopiBaddbmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t batch1,
1314
diopiConstTensorHandle_t batch2, double beta, double alpha) {
14-
diopiDtype_t outDtype;
15-
diopiGetTensorDtype(out, &outDtype);
15+
AscendTensor inAt(input);
16+
auto betas = constructDiopiScalarT(inAt.dtype(), beta);
17+
auto alphas = constructDiopiScalarT(inAt.dtype(), alpha);
1618

17-
AscendTensor inputAt(input);
18-
AscendTensor outputAt(out);
19-
AscendTensor batch1At(batch1);
20-
AscendTensor batch2At(batch2);
21-
22-
// get the size of batch1 * batch2
23-
std::vector<int64_t> batch1Shape = batch1At.shape();
24-
std::vector<int64_t> batch2Shape = batch2At.shape();
25-
std::vector<int64_t> vectorSizeBatchMatMulTensor = {batch1Shape[0], batch1Shape[1], batch2Shape[2]};
26-
27-
// init a tensor according to the size of batch1 * batch2 ;
28-
diopiSize_t diopiSizeBatchMatMulTensor = vectorToDiopiSize(vectorSizeBatchMatMulTensor);
29-
AscendTensor batchMatMulTensorAt;
30-
makeTensor(ctx, batchMatMulTensorAt, &diopiSizeBatchMatMulTensor, outDtype, diopiDevice_t::diopi_device);
31-
32-
// does batch1/batch2 need to transpose?
33-
bool isSelfT = false;
34-
bool isMat2T = false;
35-
36-
// do batch1 times batch2 -> BatchMatMulTensor
37-
AclOpRunner<2, 1>("BatchMatMul", ctx)
38-
.addInput(batch1At)
39-
.addInput(batch2At)
40-
.addOutput(batchMatMulTensorAt)
41-
.setAttr("adj_x1", isSelfT)
42-
.setAttr("adj_x2", isMat2T)
43-
.run();
44-
45-
// init memory based on the size of alphaMulTensor and betaMulTensor
46-
AscendTensor alphaMulTensor;
47-
AscendTensor betaMulTensor;
48-
makeTensorLike(ctx, alphaMulTensor, batchMatMulTensorAt, outDtype);
49-
makeTensorLike(ctx, betaMulTensor, inputAt, outDtype);
50-
51-
diopiScalar_t alphaScalar = constructDiopiScalarT(outDtype, alpha);
52-
diopiScalar_t betaScalar = constructDiopiScalarT(outDtype, beta);
53-
54-
// transform ascendTensor to diopiTensorHandle_t
55-
diopiTensorHandle_t diopiAlphaMulTensor = const_cast<diopiTensorHandle_t>(alphaMulTensor.tensorHandle());
56-
diopiTensorHandle_t diopiBateMulTensor = const_cast<diopiTensorHandle_t>(betaMulTensor.tensorHandle());
57-
diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast<diopiTensorHandle_t>(batchMatMulTensorAt.tensorHandle());
58-
diopiTensorHandle_t diopiInput = const_cast<diopiTensorHandle_t>(inputAt.tensorHandle());
59-
60-
// alpha times BatchMatMulTensor -> alphaMulTensor and beta times input -> betaMulTensor
61-
diopiMulScalar(ctx, diopiAlphaMulTensor, diopiAsBatchMatMulTensor, &alphaScalar);
62-
diopiMulScalar(ctx, diopiBateMulTensor, diopiInput, &betaScalar);
63-
64-
diopiScalar_t otherScalar = constructDiopiScalarT(outDtype, 1);
65-
diopiTensorHandle_t diopiOutput = const_cast<diopiTensorHandle_t>(outputAt.tensorHandle());
66-
diopiAdd(ctx, diopiOutput, diopiAlphaMulTensor, diopiBateMulTensor, &otherScalar);
19+
int cubeMathType = 0;
20+
DIOPI_ASCEND_CALL_ACLNN(aclnnBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, out, cubeMathType);
6721
return diopiSuccess;
6822
}
6923

7024
diopiError_t diopiBaddbmmInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t batch1, diopiConstTensorHandle_t batch2, double beta,
7125
double alpha) {
72-
return diopiBaddbmm(ctx, input, input, batch1, batch2, beta, alpha);
26+
AscendTensor inAt(input);
27+
auto betas = constructDiopiScalarT(inAt.dtype(), beta);
28+
auto alphas = constructDiopiScalarT(inAt.dtype(), alpha);
29+
30+
int cubeMathType = 0;
31+
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, cubeMathType);
32+
return diopiSuccess;
7333
}
7434

7535
} // namespace ascend

impl/ascend/functions/bmm.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,20 @@
44
* @copyright (c) 2023, DeepLink.
55
*/
66

7-
#include "../common/acloprunner.hpp"
7+
#include "../aclnn/acl_scalar.hpp"
8+
#include "../aclnn/adaptor.hpp"
89

910
namespace impl {
1011
namespace ascend {
1112

1213
diopiError_t diopiBmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t mat2) {
13-
AscendTensor inputAt(input);
14-
AscendTensor mat2At(mat2);
15-
AscendTensor outputAt(out);
16-
if (inputAt.numel() == 0 || mat2At.numel() == 0) {
17-
diopiScalar_t zero = constructDiopiScalarT(outputAt.dtype(), 0.0);
18-
diopiFill(ctx, out, &zero);
19-
return diopiSuccess;
20-
}
14+
AscendTensor inAt(input);
15+
AscendTensor matAt(mat2);
16+
ASCEND_CHECK_ABORT(inAt.dtype() == matAt.dtype(), "[diopiBmm] tensor dtypes do not match.");
17+
18+
int cubeMathType = 0;
19+
DIOPI_ASCEND_CALL_ACLNN(aclnnBatchMatMul, ctx, input, mat2, out, cubeMathType);
2120

22-
AclOpRunner<2, 1>("BatchMatMulV2", ctx).addInput(input).addInput(mat2).setAttr("adj_x1", false).setAttr("adj_x1", false).addOutput(out).run();
2321
return diopiSuccess;
2422
}
2523

impl/ascend_npu/ascend_config.yaml

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ ascend:
1313
- diopiArgmax
1414
- diopiAtan
1515
- diopiAtanInp
16+
- diopiBaddbmm
17+
- diopiBaddbmmInp
1618
- diopiBitwiseNot
1719
- diopiBitwiseNotInp
1820
- diopiBitwiseAnd
@@ -23,6 +25,7 @@ ascend:
2325
- diopiBitwiseOrInp
2426
- diopiBitwiseOrScalar
2527
- diopiBitwiseOrInpScalar
28+
- diopiBmm
2629
- diopiCastDtype
2730
- diopiClamp
2831
- diopiClampInp
@@ -201,26 +204,17 @@ ascend_npu:
201204
- diopiAddcmul
202205
- diopiAddcmulInp
203206
- diopiAddmm
204-
- diopiBaddbmm
205-
- diopiBaddbmmInp
206207
- diopiBatchNorm
207208
- diopiBatchNormBackward
208209
- diopiNonzero
209-
- diopiBmm
210210
- diopiMatmul
211-
- diopiMaxAll
212-
- diopiMin
213-
- diopiMinAll
214-
- diopiMinimum
215211
- diopiCat
216212
- diopiDropout
217213
- diopiDropoutInp
218214
- diopiCopyInp
219215
- diopiExpand
220216
- diopiGroupNorm
221217
- diopiGroupNormBackward
222-
- diopiMax
223-
- diopiMaximum
224218
- diopiMaskedFill
225219
- diopiMaskedFillInp
226220
- diopiMaskedFillInpScalar

0 commit comments

Comments (0)