Skip to content

Commit 1eb8f44

Browse files
authored
zcx/fix format and use ascend npu (DeepLink-org#787)
fix format and use ascend npu
1 parent ba39537 commit 1eb8f44

File tree

4 files changed

+35
-32
lines changed

4 files changed

+35
-32
lines changed

impl/ascend/common/format_helper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ aclFormat FormatHelper::getAclFormat(diopiMemoryFormat_t memoryFormat) {
6060
case diopiMemoryFormat_t::Undefined:
6161
return aclFormat::ACL_FORMAT_UNDEFINED;
6262
case diopiMemoryFormat_t::Contiguous:
63+
case diopiMemoryFormat_t::ND:
6364
return aclFormat::ACL_FORMAT_ND;
6465
case diopiMemoryFormat_t::NCHW:
6566
return aclFormat::ACL_FORMAT_NCHW;

impl/ascend/functions/format_cast.cpp

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,60 +12,60 @@
1212
namespace impl {
1313
namespace ascend {
1414
void formatCastInsideGroup(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t in) {
15-
AscendTensor tensor_in(in);
16-
AscendTensor tensor_out(out);
17-
AclOpRunner<1, 1>("Identity", ctx).addInput(tensor_in).addOutput(tensor_out).run();
15+
AscendTensor tensorIn(in);
16+
AscendTensor tensorOut(out);
17+
AclOpRunner<1, 1>("Identity", ctx).addInput(tensorIn).addOutput(tensorOut).run();
1818
}
1919
void formatCastBetweenGroup(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t in) {
20-
diopiStorageDesc_t desc_in;
21-
diopiGetTensorStorageDesc(in, &desc_in);
22-
diopiStorageDesc_t desc_out;
23-
diopiGetTensorStorageDesc(out, &desc_out);
24-
bool isInputBaseFormat = FormatHelper::isBaseFormat(desc_in.format);
25-
bool isOutputBaseFormat = FormatHelper::isBaseFormat(desc_out.format);
20+
diopiStorageDesc_t descIn;
21+
diopiGetTensorStorageDesc(in, &descIn);
22+
diopiStorageDesc_t descOut;
23+
diopiGetTensorStorageDesc(out, &descOut);
24+
bool isInputBaseFormat = FormatHelper::isBaseFormat(descIn.format);
25+
bool isOutputBaseFormat = FormatHelper::isBaseFormat(descOut.format);
2626
if (isInputBaseFormat && !isOutputBaseFormat) {
27-
diopiMemoryFormat_t input_format_tmp = desc_in.format;
28-
desc_in.format = FormatHelper::getDiopiBaseFormat(desc_out.format);
29-
diopiSetTensorStorageDesc(in, desc_in);
27+
diopiMemoryFormat_t inputFormatTmp = descIn.format;
28+
descIn.format = FormatHelper::getDiopiBaseFormat(descOut.format);
29+
diopiSetTensorStorageDesc(in, descIn);
3030
formatCastInsideGroup(ctx, out, in);
31-
desc_in.format = input_format_tmp;
32-
diopiSetTensorStorageDesc(in, desc_in);
31+
descIn.format = inputFormatTmp;
32+
diopiSetTensorStorageDesc(in, descIn);
3333
} else if (!isInputBaseFormat && isOutputBaseFormat) {
34-
diopiMemoryFormat_t out_format_tmp = desc_out.format;
35-
desc_out.format = FormatHelper::getDiopiBaseFormat(desc_in.format);
36-
diopiSetTensorStorageDesc(out, desc_out);
34+
diopiMemoryFormat_t outFormatTmp = descOut.format;
35+
descOut.format = FormatHelper::getDiopiBaseFormat(descIn.format);
36+
diopiSetTensorStorageDesc(out, descOut);
3737
formatCastInsideGroup(ctx, out, in);
38-
desc_out.format = out_format_tmp;
39-
diopiSetTensorStorageDesc(out, desc_out);
38+
descOut.format = outFormatTmp;
39+
diopiSetTensorStorageDesc(out, descOut);
4040
} else {
4141
ASCEND_CHECK_ABORT(false, "format cast not support");
4242
}
4343
}
4444

45-
diopiError_t diopiFormatCast(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diopiTensorHandle_t in, diopiMemoryFormat_t target_format) {
46-
AscendTensor tensor_in(in);
47-
if (tensor_in.storageFormat() == target_format) {
45+
diopiError_t diopiFormatCast(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diopiTensorHandle_t in, diopiMemoryFormat_t targetFormat) {
46+
AscendTensor tensorIn(in);
47+
if (tensorIn.storageFormat() == targetFormat) {
4848
*out = in;
4949
return diopiSuccess;
5050
}
51-
if (FormatHelper::isBaseFormat(tensor_in.storageFormat()) && FormatHelper::isBaseFormat(target_format)) {
51+
if (FormatHelper::isBaseFormat(tensorIn.storageFormat()) && FormatHelper::isBaseFormat(targetFormat)) {
5252
diopiStorageDesc_t desc;
5353
diopiGetTensorStorageDesc(in, &desc);
54-
desc.format = target_format;
54+
desc.format = targetFormat;
5555
*out = in;
5656
diopiSetTensorStorageDesc(*out, desc);
5757
return diopiSuccess;
5858
}
59-
std::vector<int64_t> storage_sizes_out = FormatHelper::getStorageSizes(target_format, tensor_in.shape());
60-
diopiStorageDesc_t desc_out;
61-
desc_out.sizes.data = storage_sizes_out.data();
62-
desc_out.sizes.len = storage_sizes_out.size();
63-
desc_out.format = target_format;
64-
diopiRequireTensor(ctx, out, &desc_out.sizes, nullptr, tensor_in.dtype(), tensor_in.device());
65-
diopiSetTensorStorageDesc(*out, desc_out);
59+
std::vector<int64_t> storageSizesOut = FormatHelper::getStorageSizes(targetFormat, tensorIn.shape());
60+
diopiStorageDesc_t descOut;
61+
descOut.sizes.data = storageSizesOut.data();
62+
descOut.sizes.len = storageSizesOut.size();
63+
descOut.format = targetFormat;
64+
diopiRequireTensor(ctx, out, &descOut.sizes, nullptr, tensorIn.dtype(), tensorIn.device());
65+
diopiSetTensorStorageDesc(*out, descOut);
6666
// set tensor metadata
6767
diopiCopyTensorMetaData(*out, in);
68-
if (FormatHelper::getDiopiBaseFormat(target_format) != FormatHelper::getDiopiBaseFormat(tensor_in.storageFormat())) {
68+
if (FormatHelper::getDiopiBaseFormat(targetFormat) != FormatHelper::getDiopiBaseFormat(tensorIn.storageFormat())) {
6969
formatCastBetweenGroup(ctx, *out, in);
7070
} else {
7171
formatCastInsideGroup(ctx, *out, in);

impl/ascend_npu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ set(OLD_IMPL_SRC
648648
${OLD_IMPL_DIR}/functions/linspace.cpp
649649
${OLD_IMPL_DIR}/functions/apply_penalty.cpp
650650
${OLD_IMPL_DIR}/functions/split.cpp
651+
${OLD_IMPL_DIR}/functions/format_cast.cpp
651652
${OLD_IMPL_DIR}/functions_ext/rms_norm.cpp
652653
#${OLD_IMPL_DIR}/test/export_functions.cpp
653654
#${OLD_IMPL_DIR}/test/conform_test.cpp

impl/ascend_npu/ascend_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ ascend:
203203
- diopiScatterScalar
204204
- diopiScatterInpScalar
205205
- diopiApplyPenalty
206+
- diopiFormatCast
206207
ascend_npu:
207208
- diopiRemainderTensor
208209
- diopiRemainderScalar

0 commit comments

Comments
 (0)