Skip to content

Commit d820aad

Browse files
[ascend]Zzf/linear (#1231)
* Reimplement linear with aclnn.

Co-authored-by: NeosZhang <[email protected]>
1 parent fe68509 commit d820aad

File tree

3 files changed: +58 additions, -90 deletions

impl/ascend/aclnn/adaptor.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ void callAclnnImpl(diopiContextHandle_t ctx, const Args&... args) {
275275

276276
/* 1. call xxxGetWorkspaceSize function. */
277277
static const auto workspaceSizeFuncAddr = getOpApiFuncAddr(workspaceApi);
278-
ASCEND_CHECK_ABORT(workspaceSizeFuncAddr != nullptr, "[%s] can't get workSpaceName function.", api);
278+
ASCEND_CHECK_THROW(workspaceSizeFuncAddr != nullptr, "[%s] can't get workSpaceName function.", api);
279279
using WorkspaceSizeFuncType = int (*)(std::decay_t<decltype(convertType(std::declval<Args>()))>..., uint64_t*, aclOpExecutor**);
280280
static const auto workspaceSizeFunc = reinterpret_cast<WorkspaceSizeFuncType>(workspaceSizeFuncAddr);
281281

@@ -288,18 +288,18 @@ void callAclnnImpl(diopiContextHandle_t ctx, const Args&... args) {
288288
aclOpExecutor* executor = nullptr;
289289
auto convertedParams = convertParams(args...);
290290
auto workspaceStatus = std::apply(workspaceSizeFunc, std::tuple_cat(convertedParams.params(), std::make_tuple(&workspaceSize, &executor)));
291-
ASCEND_CHECK_ABORT(workspaceStatus == ACL_SUCCESS, "[%s]'s workspaceStatus is not equal to ACL_SUCCESS. aclnnStatus is %d.", api, workspaceStatus);
291+
ASCEND_CHECK_THROW(workspaceStatus == ACL_SUCCESS, "[%s]'s workspaceStatus is not equal to ACL_SUCCESS. aclnnStatus is %d.", api, workspaceStatus);
292292

293293
AclWorkspace workspace(ctx, workspaceSize);
294294

295295
/* 2. call aclnnXXX function */
296296
static const auto opApiFuncAddr = getOpApiFuncAddr(api);
297-
ASCEND_CHECK_ABORT(opApiFuncAddr != nullptr, "[%s] can't get op function.", api);
297+
ASCEND_CHECK_THROW(opApiFuncAddr != nullptr, "[%s] can't get op function.", api);
298298
using OpApiFuncType = int (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
299299
static const auto opApiFunc = reinterpret_cast<OpApiFuncType>(opApiFuncAddr);
300300

301301
auto ret = opApiFunc(workspace.addr(), workspaceSize, executor, stream);
302-
ASCEND_CHECK_ABORT(ret == ACL_SUCCESS, "[%s] failed. aclnnStatus is %d.", api, ret);
302+
ASCEND_CHECK_THROW(ret == ACL_SUCCESS, "[%s] failed. aclnnStatus is %d.", api, ret);
303303
}
304304

305305
#define DIOPI_ASCEND_CALL_ACLNN(api, ctx, ...) \

impl/ascend/ascend_tensor.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ namespace ascend {
6060
} \
6161
} while (0);
6262

63+
#define ASCEND_CHECK_THROW(condition, ...) \
64+
do { \
65+
if (!(condition)) { \
66+
printf("[%s:%s:%d]: ", __FILE__, __FUNCTION__, __LINE__); \
67+
printf(__VA_ARGS__); \
68+
printf("\n"); \
69+
throw std::runtime_error(std::string("ascend device error:") + aclGetRecentErrMsg()); \
70+
} \
71+
} while (0);
72+
6373
#define ASCEND_CHECK_NULLPTR_ABORT(ptr) ASCEND_CHECK_ABORT(ptr, "Variable is nullptr, pls check.")
6474

6575
inline void error(const char* file, int lineNum, const char* funcName, const char* format, ...) {

impl/ascend/functions/linear.cpp

Lines changed: 44 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -4,111 +4,69 @@
44
* @copyright (c) 2023, DeepLink.
55
*/
66

7-
#include <numeric>
8-
9-
#include "../common/acloprunner.hpp"
7+
#include "../aclnn/acl_scalar.hpp"
8+
#include "../aclnn/adaptor.hpp"
109

1110
namespace impl {
1211
namespace ascend {
1312
diopiError_t diopiLinear(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight,
1413
diopiConstTensorHandle_t bias) {
15-
// convert inputs to AscendTensor
16-
AscendTensor inputCopy(input);
17-
AscendTensor outputCopy(out);
18-
AscendTensor weightCopy(weight);
19-
const std::vector<int64_t> outputPrimaryShape = outputCopy.shape();
20-
21-
if (inputCopy.numel() == 0 || weightCopy.numel() == 0) {
22-
diopiScalar_t zero = constructDiopiScalarT(outputCopy.dtype(), 0.0);
23-
diopiFill(ctx, out, &zero);
24-
return diopiSuccess;
25-
}
26-
27-
// mm's input matrix must be 2D, it needs to be converted if it isn't
28-
if (inputCopy.shape().size() > 2) {
29-
transTensorTo2D(ctx, inputCopy);
30-
}
31-
if (outputCopy.shape().size() > 2) {
32-
transTensorTo2D(ctx, outputCopy);
14+
diopiTensorHandle_t weightT;
15+
diopiSize_t weightSize;
16+
diopiGetTensorShape(weight, &weightSize);
17+
diopiDtype_t weightDtype;
18+
diopiGetTensorDtype(weight, &weightDtype);
19+
std::vector<int64_t> weightTShape(weightSize.data, weightSize.data + weightSize.len);
20+
weightTShape[weightSize.len - 1] = weightSize.data[weightSize.len - 2];
21+
weightTShape[weightSize.len - 2] = weightSize.data[weightSize.len - 1];
22+
diopiSize_t weightTSize = {weightTShape.data(), static_cast<int64_t>(weightTShape.size())};
23+
diopiRequireTensor(ctx, &weightT, &weightTSize, nullptr, weightDtype, diopi_device);
24+
std::vector<int64_t> dims = {1, 0};
25+
DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, weight, dims, weightT);
26+
DIOPI_ASCEND_CALL_ACLNN(aclnnMatmul, ctx, input, weightT, out, 0);
27+
28+
if (nullptr != bias) {
29+
diopiDtype_t outDtype;
30+
diopiGetTensorDtype(out, &outDtype);
31+
diopiScalar_t alpha = constructDiopiScalarT(outDtype, 1);
32+
DIOPI_ASCEND_CALL_ACLNN(aclnnInplaceAdd, ctx, out, bias, &alpha);
3333
}
3434

35-
AclOpRunner<3, 1> runner("MatMulV2", ctx);
36-
runner.addInput(inputCopy).addInput(weightCopy).setAttr<uint8_t>("transpose_x1", false).setAttr<uint8_t>("transpose_x2", true).addOutput(outputCopy);
37-
38-
// if bias is not nullptr, also add bias to input
39-
if (bias) {
40-
runner.addInput(bias);
41-
}
42-
runner.run();
43-
4435
return diopiSuccess;
4536
}
4637

4738
diopiError_t diopiLinearBackward(diopiContextHandle_t ctx, diopiTensorHandle_t gradInput, diopiTensorHandle_t gradWeight, diopiTensorHandle_t gradBias,
4839
diopiConstTensorHandle_t gradOutput, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight) {
49-
AscendTensor gradWeightCopy(gradWeight);
50-
AscendTensor gradOutputCopy(gradOutput);
51-
AscendTensor inputCopy(input);
52-
AscendTensor weightCopy(weight);
53-
54-
const std::vector<int64_t> gradInputPrimaryShape = inputCopy.shape();
55-
bool transTensorTo2DFalg = false;
56-
57-
if (gradOutputCopy.numel() == 0 || weightCopy.numel() == 0 || inputCopy.numel() == 0) {
58-
diopiScalar_t zero = constructDiopiScalarT(inputCopy.dtype(), 0.0);
59-
diopiFill(ctx, gradInput, &zero);
60-
diopiFill(ctx, gradWeight, &zero);
61-
diopiFill(ctx, gradBias, &zero);
62-
return diopiSuccess;
63-
}
64-
65-
if (weightCopy.shape().size() > 2) transTensorTo2D(ctx, weightCopy);
66-
if (gradOutputCopy.shape().size() > 2) transTensorTo2D(ctx, gradOutputCopy);
67-
6840
if (nullptr != gradInput) {
69-
AscendTensor gradInputCopy(gradInput);
70-
if (inputCopy.shape().size() > 2) {
71-
transTensorTo2DFalg = true;
72-
transTensorTo2D(ctx, gradInputCopy);
73-
}
74-
75-
AclOpRunner<2, 1>("MatMul", ctx)
76-
.addInput(gradOutputCopy)
77-
.addInput(weightCopy)
78-
.setAttr<uint8_t>("transpose_x1", false)
79-
.setAttr<uint8_t>("transpose_x2", false)
80-
.addOutput(gradInputCopy)
81-
.run();
82-
83-
if (transTensorTo2DFalg) {
84-
gradInputCopy.view(gradInputPrimaryShape);
85-
}
41+
DIOPI_ASCEND_CALL_ACLNN(aclnnMatmul, ctx, gradOutput, weight, gradInput, 0);
8642
}
8743

88-
if (inputCopy.shape().size() > 2) transTensorTo2D(ctx, inputCopy);
89-
9044
if (nullptr != gradWeight) {
91-
if (gradWeightCopy.shape().size() > 2) transTensorTo2D(ctx, gradWeightCopy);
92-
93-
AclOpRunner<2, 1>("MatMul", ctx)
94-
.addInput(gradOutputCopy)
95-
.addInput(inputCopy)
96-
.setAttr<uint8_t>("transpose_x1", true)
97-
.setAttr<uint8_t>("transpose_x2", false)
98-
.addOutput(gradWeightCopy)
99-
.run();
45+
AscendTensor input2D(input);
46+
if (input2D.dim() > 2) transTensorTo2D(ctx, input2D);
47+
AscendTensor gradOutput2D(gradOutput);
48+
if (gradOutput2D.dim() > 2) transTensorTo2D(ctx, gradOutput2D);
49+
50+
diopiTensorHandle_t gradOutput2DT;
51+
std::vector<int64_t> gradOutput2DTShape = {gradOutput2D.shape()[1], gradOutput2D.shape()[0]};
52+
diopiSize_t gradOutput2DTSize = {gradOutput2DTShape.data(), static_cast<int64_t>(gradOutput2DTShape.size())};
53+
diopiRequireTensor(ctx, &gradOutput2DT, &gradOutput2DTSize, nullptr, gradOutput2D.dtype(), diopi_device);
54+
55+
std::vector<int64_t> dims = {1, 0};
56+
DIOPI_ASCEND_CALL_ACLNN(aclnnPermute, ctx, gradOutput2D, dims, gradOutput2DT);
57+
DIOPI_ASCEND_CALL_ACLNN(aclnnMatmul, ctx, gradOutput2DT, input2D, gradWeight, 0);
10058
}
10159

102-
AscendTensor reshapedGradOutputCopy;
103-
makeTensorLike(ctx, reshapedGradOutputCopy, gradOutputCopy, gradOutputCopy.dtype());
104-
reshape(ctx, gradOutputCopy, reshapedGradOutputCopy, gradOutputCopy.shape());
60+
if (nullptr != gradBias) {
61+
diopiSize_t gradOutputSize;
62+
diopiGetTensorShape(gradOutput, &gradOutputSize);
63+
std::vector<int64_t> dims(gradOutputSize.len - 1);
64+
std::iota(std::begin(dims), std::end(dims), 0);
10565

106-
diopiTensorHandle_t diopiGradOutputCopy = const_cast<diopiTensorHandle_t>(reshapedGradOutputCopy.tensorHandle());
107-
if (gradBias) {
108-
std::vector<int64_t> dimVec(gradOutputCopy.shape().size() - 1);
109-
std::iota(std::begin(dimVec), std::end(dimVec), 0);
110-
diopiSize_t dim = vectorToDiopiSize(dimVec);
111-
diopiSum(ctx, gradBias, diopiGradOutputCopy, dim);
66+
diopiDtype_t biasDtype;
67+
diopiGetTensorDtype(gradBias, &biasDtype);
68+
aclDataType dtype = getAclDataType(biasDtype);
69+
DIOPI_ASCEND_CALL_ACLNN(aclnnReduceSum, ctx, gradOutput, dims, false, dtype, gradBias);
11270
}
11371

11472
return diopiSuccess;

0 commit comments

Comments (0)