Skip to content

Commit d494a88

Browse files
authored
[Ascend] fuj/fix-AscendTensor-storage-offset (#1259)
1 parent d820aad commit d494a88

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

impl/ascend/aclnn/adaptor.hpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,42 @@ inline void* getOpApiFuncAddr(const char* apiName) {
5656
return getOpApiFuncAddrInLib(opApiHandler, kOpApiLibName, apiName);
5757
}
5858

59+
inline aclFormat storageFormatByDimNum(int64_t dimNum) {
60+
aclFormat format = ACL_FORMAT_ND;
61+
switch (dimNum) {
62+
case 3:
63+
format = ACL_FORMAT_NCL;
64+
break;
65+
case 4:
66+
format = ACL_FORMAT_NCHW;
67+
break;
68+
case 5:
69+
format = ACL_FORMAT_NCDHW;
70+
break;
71+
default:
72+
format = ACL_FORMAT_ND;
73+
}
74+
return format;
75+
}
76+
5977
inline aclTensor* createAclTensorFromAscendTensor(const AscendTensor& input) {
6078
const auto& shape = input.shape();
6179
const auto& stride = input.stride();
6280
const auto storageSize = static_cast<int64_t>(input.storageNbytes() / input.elemsize());
81+
82+
void* storagePtr = nullptr;
83+
diopiGetTensorStoragePtr(input.tensorHandle(), &storagePtr);
84+
auto format = storageFormatByDimNum(input.dim());
85+
6386
return ::aclCreateTensor(shape.data(),
6487
shape.size(),
6588
input.getAclDataType(),
6689
stride.data(),
6790
input.storageOffset(),
68-
input.getAclDataFormat(), // TODO(lljbash): op_plugin assume non-channel-last, why?
91+
format, // input.getAclDataFormat(), // TODO(lljbash): op_plugin assume non-channel-last, why?
6992
&storageSize,
7093
/*storageDimsNum=*/1,
71-
const_cast<void*>(input.data()));
94+
const_cast<void*>(storagePtr));
7295
}
7396

7497
inline aclTensor* createAclTensorFromDiopiTensor(diopiConstTensorHandle_t tensor) {
@@ -90,20 +113,15 @@ inline aclTensor* createAclTensorFromDiopiTensor(diopiConstTensorHandle_t tensor
90113
diopiGetTensorStorageOffset(tensor, &storageOffset);
91114
std::size_t storageNbytes{};
92115
diopiGetTensorStorageNbytes(tensor, &storageNbytes);
93-
const void* tensorData = nullptr;
94-
diopiGetTensorDataConst(tensor, &tensorData);
116+
117+
void* storagePtr = nullptr;
118+
diopiGetTensorStoragePtr(tensor, &storagePtr);
119+
95120
auto type = diopiDtypeToAclDataType(dtype);
96-
auto format = inferAclDataFormat(shape.len, shape.data, stride.data);
121+
auto format = storageFormatByDimNum(shape.len);
97122
auto storageSize = static_cast<int64_t>(storageNbytes / elemsize);
98-
return ::aclCreateTensor(shape.data,
99-
shape.len,
100-
type,
101-
stride.data,
102-
storageOffset,
103-
format,
104-
&storageSize,
105-
/*storageDimsNum=*/1,
106-
const_cast<void*>(tensorData));
123+
124+
return ::aclCreateTensor(shape.data, shape.len, type, stride.data, storageOffset, format, &storageSize, 1, const_cast<void*>(storagePtr));
107125
}
108126

109127
inline aclScalar* createAclScalarFromDiopiScalar(const diopiScalar_t* scalar) {

impl/ascend/convert_config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@
3333

3434
- diopiConvolution2d:
3535
dtype: (float64)->float32
36+
layout: ND
3637

3738
- diopiConvolution2dBackward:
3839
dtype: (float64)->float32
40+
layout: ND
3941

4042
- diopiAdaptiveAvgPool2d:
4143
dtype: (float64)->float32

0 commit comments

Comments
 (0)