|
12 | 12 | namespace impl { |
13 | 13 | namespace ascend { |
14 | 14 | void formatCastInsideGroup(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t in) { |
15 | | - AscendTensor tensor_in(in); |
16 | | - AscendTensor tensor_out(out); |
17 | | - AclOpRunner<1, 1>("Identity", ctx).addInput(tensor_in).addOutput(tensor_out).run(); |
| 15 | + AscendTensor tensorIn(in); |
| 16 | + AscendTensor tensorOut(out); |
| 17 | + AclOpRunner<1, 1>("Identity", ctx).addInput(tensorIn).addOutput(tensorOut).run(); |
18 | 18 | } |
19 | 19 | void formatCastBetweenGroup(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t in) { |
20 | | - diopiStorageDesc_t desc_in; |
21 | | - diopiGetTensorStorageDesc(in, &desc_in); |
22 | | - diopiStorageDesc_t desc_out; |
23 | | - diopiGetTensorStorageDesc(out, &desc_out); |
24 | | - bool isInputBaseFormat = FormatHelper::isBaseFormat(desc_in.format); |
25 | | - bool isOutputBaseFormat = FormatHelper::isBaseFormat(desc_out.format); |
| 20 | + diopiStorageDesc_t descIn; |
| 21 | + diopiGetTensorStorageDesc(in, &descIn); |
| 22 | + diopiStorageDesc_t descOut; |
| 23 | + diopiGetTensorStorageDesc(out, &descOut); |
| 24 | + bool isInputBaseFormat = FormatHelper::isBaseFormat(descIn.format); |
| 25 | + bool isOutputBaseFormat = FormatHelper::isBaseFormat(descOut.format); |
26 | 26 | if (isInputBaseFormat && !isOutputBaseFormat) { |
27 | | - diopiMemoryFormat_t input_format_tmp = desc_in.format; |
28 | | - desc_in.format = FormatHelper::getDiopiBaseFormat(desc_out.format); |
29 | | - diopiSetTensorStorageDesc(in, desc_in); |
| 27 | + diopiMemoryFormat_t inputFormatTmp = descIn.format; |
| 28 | + descIn.format = FormatHelper::getDiopiBaseFormat(descOut.format); |
| 29 | + diopiSetTensorStorageDesc(in, descIn); |
30 | 30 | formatCastInsideGroup(ctx, out, in); |
31 | | - desc_in.format = input_format_tmp; |
32 | | - diopiSetTensorStorageDesc(in, desc_in); |
| 31 | + descIn.format = inputFormatTmp; |
| 32 | + diopiSetTensorStorageDesc(in, descIn); |
33 | 33 | } else if (!isInputBaseFormat && isOutputBaseFormat) { |
34 | | - diopiMemoryFormat_t out_format_tmp = desc_out.format; |
35 | | - desc_out.format = FormatHelper::getDiopiBaseFormat(desc_in.format); |
36 | | - diopiSetTensorStorageDesc(out, desc_out); |
| 34 | + diopiMemoryFormat_t outFormatTmp = descOut.format; |
| 35 | + descOut.format = FormatHelper::getDiopiBaseFormat(descIn.format); |
| 36 | + diopiSetTensorStorageDesc(out, descOut); |
37 | 37 | formatCastInsideGroup(ctx, out, in); |
38 | | - desc_out.format = out_format_tmp; |
39 | | - diopiSetTensorStorageDesc(out, desc_out); |
| 38 | + descOut.format = outFormatTmp; |
| 39 | + diopiSetTensorStorageDesc(out, descOut); |
40 | 40 | } else { |
41 | 41 | ASCEND_CHECK_ABORT(false, "format cast not support"); |
42 | 42 | } |
43 | 43 | } |
44 | 44 |
|
45 | | -diopiError_t diopiFormatCast(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diopiTensorHandle_t in, diopiMemoryFormat_t target_format) { |
46 | | - AscendTensor tensor_in(in); |
47 | | - if (tensor_in.storageFormat() == target_format) { |
| 45 | +diopiError_t diopiFormatCast(diopiContextHandle_t ctx, diopiTensorHandle_t* out, diopiTensorHandle_t in, diopiMemoryFormat_t targetFormat) { |
| 46 | + AscendTensor tensorIn(in); |
| 47 | + if (tensorIn.storageFormat() == targetFormat) { |
48 | 48 | *out = in; |
49 | 49 | return diopiSuccess; |
50 | 50 | } |
51 | | - if (FormatHelper::isBaseFormat(tensor_in.storageFormat()) && FormatHelper::isBaseFormat(target_format)) { |
| 51 | + if (FormatHelper::isBaseFormat(tensorIn.storageFormat()) && FormatHelper::isBaseFormat(targetFormat)) { |
52 | 52 | diopiStorageDesc_t desc; |
53 | 53 | diopiGetTensorStorageDesc(in, &desc); |
54 | | - desc.format = target_format; |
| 54 | + desc.format = targetFormat; |
55 | 55 | *out = in; |
56 | 56 | diopiSetTensorStorageDesc(*out, desc); |
57 | 57 | return diopiSuccess; |
58 | 58 | } |
59 | | - std::vector<int64_t> storage_sizes_out = FormatHelper::getStorageSizes(target_format, tensor_in.shape()); |
60 | | - diopiStorageDesc_t desc_out; |
61 | | - desc_out.sizes.data = storage_sizes_out.data(); |
62 | | - desc_out.sizes.len = storage_sizes_out.size(); |
63 | | - desc_out.format = target_format; |
64 | | - diopiRequireTensor(ctx, out, &desc_out.sizes, nullptr, tensor_in.dtype(), tensor_in.device()); |
65 | | - diopiSetTensorStorageDesc(*out, desc_out); |
| 59 | + std::vector<int64_t> storageSizesOut = FormatHelper::getStorageSizes(targetFormat, tensorIn.shape()); |
| 60 | + diopiStorageDesc_t descOut; |
| 61 | + descOut.sizes.data = storageSizesOut.data(); |
| 62 | + descOut.sizes.len = storageSizesOut.size(); |
| 63 | + descOut.format = targetFormat; |
| 64 | + diopiRequireTensor(ctx, out, &descOut.sizes, nullptr, tensorIn.dtype(), tensorIn.device()); |
| 65 | + diopiSetTensorStorageDesc(*out, descOut); |
66 | 66 | // set tensor metadata |
67 | 67 | diopiCopyTensorMetaData(*out, in); |
68 | | - if (FormatHelper::getDiopiBaseFormat(target_format) != FormatHelper::getDiopiBaseFormat(tensor_in.storageFormat())) { |
| 68 | + if (FormatHelper::getDiopiBaseFormat(targetFormat) != FormatHelper::getDiopiBaseFormat(tensorIn.storageFormat())) { |
69 | 69 | formatCastBetweenGroup(ctx, *out, in); |
70 | 70 | } else { |
71 | 71 | formatCastInsideGroup(ctx, *out, in); |
|
0 commit comments