 * @copyright (c) 2023, DeepLink.
 */

-#include "../common/acloprunner.hpp"
+#include "../aclnn/acl_scalar.hpp"
+#include "../aclnn/adaptor.hpp"

 namespace impl {
 namespace ascend {

 diopiError_t diopiBaddbmm(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t batch1,
                           diopiConstTensorHandle_t batch2, double beta, double alpha) {
-    diopiDtype_t outDtype;
-    diopiGetTensorDtype(out, &outDtype);
+    AscendTensor inAt(input);
+    auto betas = constructDiopiScalarT(inAt.dtype(), beta);
+    auto alphas = constructDiopiScalarT(inAt.dtype(), alpha);

-    AscendTensor inputAt(input);
-    AscendTensor outputAt(out);
-    AscendTensor batch1At(batch1);
-    AscendTensor batch2At(batch2);
-
-    // get the size of batch1 * batch2
-    std::vector<int64_t> batch1Shape = batch1At.shape();
-    std::vector<int64_t> batch2Shape = batch2At.shape();
-    std::vector<int64_t> vectorSizeBatchMatMulTensor = {batch1Shape[0], batch1Shape[1], batch2Shape[2]};
-
-    // init a tensor according to the size of batch1 * batch2
-    diopiSize_t diopiSizeBatchMatMulTensor = vectorToDiopiSize(vectorSizeBatchMatMulTensor);
-    AscendTensor batchMatMulTensorAt;
-    makeTensor(ctx, batchMatMulTensorAt, &diopiSizeBatchMatMulTensor, outDtype, diopiDevice_t::diopi_device);
-
-    // does batch1/batch2 need to transpose?
-    bool isSelfT = false;
-    bool isMat2T = false;
-
-    // do batch1 times batch2 -> BatchMatMulTensor
-    AclOpRunner<2, 1>("BatchMatMul", ctx)
-        .addInput(batch1At)
-        .addInput(batch2At)
-        .addOutput(batchMatMulTensorAt)
-        .setAttr("adj_x1", isSelfT)
-        .setAttr("adj_x2", isMat2T)
-        .run();
-
-    // init memory based on the size of alphaMulTensor and betaMulTensor
-    AscendTensor alphaMulTensor;
-    AscendTensor betaMulTensor;
-    makeTensorLike(ctx, alphaMulTensor, batchMatMulTensorAt, outDtype);
-    makeTensorLike(ctx, betaMulTensor, inputAt, outDtype);
-
-    diopiScalar_t alphaScalar = constructDiopiScalarT(outDtype, alpha);
-    diopiScalar_t betaScalar = constructDiopiScalarT(outDtype, beta);
-
-    // transform ascendTensor to diopiTensorHandle_t
-    diopiTensorHandle_t diopiAlphaMulTensor = const_cast<diopiTensorHandle_t>(alphaMulTensor.tensorHandle());
-    diopiTensorHandle_t diopiBateMulTensor = const_cast<diopiTensorHandle_t>(betaMulTensor.tensorHandle());
-    diopiTensorHandle_t diopiAsBatchMatMulTensor = const_cast<diopiTensorHandle_t>(batchMatMulTensorAt.tensorHandle());
-    diopiTensorHandle_t diopiInput = const_cast<diopiTensorHandle_t>(inputAt.tensorHandle());
-
-    // alpha times BatchMatMulTensor -> alphaMulTensor and beta times input -> betaMulTensor
-    diopiMulScalar(ctx, diopiAlphaMulTensor, diopiAsBatchMatMulTensor, &alphaScalar);
-    diopiMulScalar(ctx, diopiBateMulTensor, diopiInput, &betaScalar);
-
-    diopiScalar_t otherScalar = constructDiopiScalarT(outDtype, 1);
-    diopiTensorHandle_t diopiOutput = const_cast<diopiTensorHandle_t>(outputAt.tensorHandle());
-    diopiAdd(ctx, diopiOutput, diopiAlphaMulTensor, diopiBateMulTensor, &otherScalar);
+    int cubeMathType = 0;
+    DIOPI_ASCEND_CALL_ACLNN(aclnnBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, out, cubeMathType);
     return diopiSuccess;
 }

 diopiError_t diopiBaddbmmInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t batch1, diopiConstTensorHandle_t batch2, double beta,
                              double alpha) {
-    return diopiBaddbmm(ctx, input, input, batch1, batch2, beta, alpha);
+    AscendTensor inAt(input);
+    auto betas = constructDiopiScalarT(inAt.dtype(), beta);
+    auto alphas = constructDiopiScalarT(inAt.dtype(), alpha);
+
+    int cubeMathType = 0;
+    DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceBaddbmm, ctx, input, batch1, batch2, &betas, &alphas, cubeMathType);
+    return diopiSuccess;
 }

 }  // namespace ascend
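The change replaces the hand-rolled decomposition (a `BatchMatMul` issued through `AclOpRunner`, two `diopiMulScalar` calls, and a final `diopiAdd`, each with its own intermediate tensor) by a single fused call to `aclnnBaddbmm`, or `aclnnInplaceBaddbmm` for the in-place variant, dispatched through the `DIOPI_ASCEND_CALL_ACLNN` adaptor; the `beta` and `alpha` scalars are now constructed from the input tensor's dtype rather than the output's. For reference, below is a minimal CPU-only sketch of the semantics both the old and the new path compute, `out[b] = beta * input[b] + alpha * (batch1[b] x batch2[b])` for each batch index `b`. The function `baddbmmReference` and its raw row-major buffer interface are illustrative assumptions only, not part of DIOPI or the aclnn API.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical reference for the baddbmm semantics:
//   out[b] = beta * input[b] + alpha * (batch1[b] x batch2[b])
// with batch1 of shape (B, M, K), batch2 of shape (B, K, N),
// and input/out of shape (B, M, N), all stored row-major.
void baddbmmReference(std::vector<float>& out, const std::vector<float>& input, const std::vector<float>& batch1, const std::vector<float>& batch2,
                      int64_t B, int64_t M, int64_t K, int64_t N, double beta, double alpha) {
    for (int64_t b = 0; b < B; ++b) {
        for (int64_t m = 0; m < M; ++m) {
            for (int64_t n = 0; n < N; ++n) {
                // accumulate the (m, n) entry of the b-th matrix product
                double acc = 0.0;
                for (int64_t k = 0; k < K; ++k) {
                    acc += batch1[(b * M + m) * K + k] * batch2[(b * K + k) * N + n];
                }
                const int64_t idx = (b * M + m) * N + n;
                // the in-place variant is the same formula with out aliased to input
                out[idx] = static_cast<float>(beta * input[idx] + alpha * acc);
            }
        }
    }
}
```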