|
6 | 6 | * LICENSE file in the root directory of this source tree. |
7 | 7 | */ |
8 | 8 |
|
9 | | -#include <executorch/backends/cadence/generic/kernels/kernels.h> |
| 9 | +#include <executorch/backends/cadence/generic/operators/op_requantize.h> |
| 10 | +#include <executorch/runtime/core/exec_aten/exec_aten.h> |
10 | 11 | #include <executorch/runtime/kernel/kernel_includes.h> |
11 | 12 |
|
| 13 | +#include <executorch/backends/cadence/generic/kernels/kernels.h> |
| 14 | +#include <cstdint> |
| 15 | +#include <cstdlib> |
| 16 | + |
12 | 17 | namespace impl { |
13 | 18 | namespace generic { |
14 | 19 | namespace native { |
15 | 20 |
|
16 | | -using executorch::aten::ScalarType; |
17 | | -using executorch::aten::Tensor; |
18 | | -using executorch::runtime::KernelRuntimeContext; |
| 21 | +using ::executorch::aten::IntArrayRef; |
| 22 | +using ::executorch::aten::optional; |
| 23 | +using ::executorch::aten::Scalar; |
| 24 | +using ::executorch::aten::ScalarType; |
| 25 | +using ::executorch::aten::Tensor; |
| 26 | +using ::executorch::runtime::KernelRuntimeContext; |
| 27 | +using ::impl::generic::kernels::dequantize; |
| 28 | +using ::impl::generic::kernels::quantize; |
19 | 29 |
|
20 | 30 | // Requantize the int8_t/uint8_t input tensor to a uint8_t/int8_t out tensor. |
21 | 31 | // The scale and zero_point for requantization are in the args. |
@@ -86,15 +96,14 @@ Tensor& requantize_out( |
86 | 96 | torch::executor::toString(out.scalar_type()), |
87 | 97 | torch::executor::toString(out_dtype)); |
88 | 98 |
|
89 | | -#define typed_requantize(ctype, dtype) \ |
90 | | - const ctype* input_data = input.const_data_ptr<ctype>(); \ |
91 | | - dtype* out_data = out.mutable_data_ptr<dtype>(); \ |
92 | | - for (size_t i = 0; i < numel; ++i) { \ |
93 | | - float dequant = \ |
94 | | - kernels::dequantize<ctype>(input_data[i], in_scale, in_zero_point); \ |
95 | | - out_data[i] = \ |
96 | | - kernels::quantize<dtype>(dequant, 1 / out_scale, out_zero_point); \ |
| 99 | +#define typed_requantize(ctype, dtype) \ |
| 100 | + const ctype* input_data = input.const_data_ptr<ctype>(); \ |
| 101 | + dtype* out_data = out.mutable_data_ptr<dtype>(); \ |
| 102 | + for (size_t i = 0; i < numel; ++i) { \ |
| 103 | + float dequant = dequantize<ctype>(input_data[i], in_scale, in_zero_point); \ |
| 104 | + out_data[i] = quantize<dtype>(dequant, 1 / out_scale, out_zero_point); \ |
97 | 105 | }; |
| 106 | + |
98 | 107 | #define typed_requantize_in(ctype) \ |
99 | 108 | switch (out_dtype) { \ |
100 | 109 | case ScalarType::Byte: { \ |
@@ -187,14 +196,12 @@ Tensor& requantize_per_tensor_out( |
187 | 196 | torch::executor::toString(out.scalar_type()), |
188 | 197 | torch::executor::toString(out_dtype)); |
189 | 198 |
|
190 | | -#define typed_requantize(ctype, dtype) \ |
191 | | - const ctype* input_data = input.const_data_ptr<ctype>(); \ |
192 | | - dtype* out_data = out.mutable_data_ptr<dtype>(); \ |
193 | | - for (size_t i = 0; i < numel; ++i) { \ |
194 | | - float dequant = \ |
195 | | - kernels::dequantize<ctype>(input_data[i], in_scale, in_zero_point); \ |
196 | | - out_data[i] = \ |
197 | | - kernels::quantize<dtype>(dequant, 1 / out_scale, out_zero_point); \ |
| 199 | +#define typed_requantize(ctype, dtype) \ |
| 200 | + const ctype* input_data = input.const_data_ptr<ctype>(); \ |
| 201 | + dtype* out_data = out.mutable_data_ptr<dtype>(); \ |
| 202 | + for (size_t i = 0; i < numel; ++i) { \ |
| 203 | + float dequant = dequantize<ctype>(input_data[i], in_scale, in_zero_point); \ |
| 204 | + out_data[i] = quantize<dtype>(dequant, 1 / out_scale, out_zero_point); \ |
198 | 205 | }; |
199 | 206 |
|
200 | 207 | #define typed_requantize_in(ctype) \ |
|
0 commit comments