/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/generic/operators/op_avg_pool2d.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

namespace impl {
namespace generic {
namespace native {

using ::executorch::aten::IntArrayRef;
using ::executorch::aten::optional;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::runtime::getLeadingDims;
using ::executorch::runtime::KernelRuntimeContext;

// Compute avg_pool2d for in_data in NCHW layout. IT is the input datatype,
// and AT is the accumulation datatype. 'quantized' is true when the input is
// a quantized tensor.
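//
// Parameter notes (inferred from the call site below): 'leading_dims' is the
// product of all dimensions before H and W (N * C for a 4-D NCHW tensor);
// 'ih'/'iw' and 'oh'/'ow' are the input and output spatial extents; 'divisor'
// is the denominator used when count_include_pad is true.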
template <typename IT, typename AT = IT, bool quantized = false>
void avg_pool2d_nchw(
    const IT* __restrict__ in_data,
    const int32_t in_zero_point,
    IT* __restrict__ out_data,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool count_include_pad,
    int64_t divisor,
    int leading_dims,
    int ih,
    int iw,
    int oh,
    int ow) {
  // Unpack kernel size, stride, and padding along H (index 0) and W (index 1).
  int kh = kernel_size[0];
  int kw = kernel_size[1];
  int s0 = stride[0];
  int s1 = stride[1];
  int p0 = padding[0];
  int p1 = padding[1];

  for (int _n = 0; _n < leading_dims; ++_n) {
    for (int _ih = 0, _oh = 0; _oh < oh; ++_oh, _ih += s0) {
      int input_offset = _n * ih * iw;
      int output_offset = _n * oh * ow + _oh * ow;
      for (int _iw = 0, _ow = 0; _ow < ow; ++_ow, _iw += s1) {
        // Clamp the pooling window to the valid input region.
        int kh_lo = std::max(0, _ih - p0);
        int kh_hi = std::min(ih, _ih + kh - p0);
        int kw_lo = std::max(0, _iw - p1);
        int kw_hi = std::min(iw, _iw + kw - p1);
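        // Worked example: with ih = 5, kh = 3, p0 = 1, the first output row
        // has _ih = 0, so kh_lo = max(0, -1) = 0 and kh_hi = min(5, 2) = 2;
        // only two of the three kernel rows overlap the input, and the third
        // falls in the padding.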
        // Count the number of contributions sans padding
        int count = (kh_hi - kh_lo) * (kw_hi - kw_lo);
        // Initialize the accumulator
        AT acc = count_include_pad ? in_zero_point * (kh * kw - count) : 0;
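        // When padding is counted, each padded cell is treated as holding the
        // zero point: for quantized inputs that dequantizes to 0.0, and for
        // float inputs the zero point is expected to be 0, so the term
        // vanishes.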
        // Accumulate values
        for (int _kh = kh_lo; _kh < kh_hi; ++_kh) {
          for (int _kw = kw_lo; _kw < kw_hi; ++_kw) {
            int input_addr = input_offset + _kh * iw + _kw;
            acc += in_data[input_addr];
          }
        }
        // The divisor changes depending on whether the count includes
        // padded cells or not.
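        // E.g., for a 3x3 kernel at a corner output with padding 1, count is
        // 4: with count_include_pad the sum is divided by 9 (the divisor),
        // otherwise by the 4 cells that actually overlap the input.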
        float inv_divisor = 1. / (count_include_pad ? divisor : count);
        float val = acc * inv_divisor;
        if constexpr (quantized) {
          // Round (nearbyint) and clamp to the representable range of IT.
          int32_t min_val =
              static_cast<int32_t>(std::numeric_limits<IT>::min());
          int32_t max_val =
              static_cast<int32_t>(std::numeric_limits<IT>::max());
          out_data[output_offset + _ow] = std::min(
              std::max(int32_t(std::nearbyint(val)), min_val), max_val);
        } else {
          out_data[output_offset + _ow] = val;
        }
      }
    }
  }
}
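
// A minimal usage sketch of the kernel above (illustrative only; the buffer
// shapes and the pointer/length IntArrayRef construction are assumptions, not
// code that exists elsewhere in this backend):
//
//   const float in[16] = {1, 2, 3, 4, 5, 6, 7, 8,
//                         9, 10, 11, 12, 13, 14, 15, 16};
//   float out[4];
//   const int64_t ks[2] = {2, 2}, st[2] = {2, 2}, pd[2] = {0, 0};
//   avg_pool2d_nchw<float>(
//       in, /*in_zero_point=*/0, out,
//       IntArrayRef(ks, 2), IntArrayRef(st, 2), IntArrayRef(pd, 2),
//       /*count_include_pad=*/false, /*divisor=*/4,
//       /*leading_dims=*/1, /*ih=*/4, /*iw=*/4, /*oh=*/2, /*ow=*/2);
//   // out == {3.5f, 5.5f, 11.5f, 13.5f}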

Tensor& avg_pool2d_out(
    ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    optional<int64_t> divisor_override,
    const optional<Tensor>& in_zero_point_t,
    bool channel_last,
    Tensor& out) {
  ET_DCHECK_MSG(!channel_last, "NHWC layout for avg_pool2d not yet supported");
  const int32_t in_zero_point = in_zero_point_t.has_value()
      ? in_zero_point_t.value().const_data_ptr<int32_t>()[0]
      : 0;
  const int64_t divisor = divisor_override.has_value()
      ? divisor_override.value()
      : kernel_size[0] * kernel_size[1];
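  // E.g., a 3x3 kernel yields a default divisor of 9. Note that the kernel
  // above applies this divisor only when count_include_pad is true; otherwise
  // it divides by the per-window element count.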

  const int odim = out.dim();
  const int on = getLeadingDims(out, odim - 2);
  const int oh = out.size(odim - 2);
  const int ow = out.size(odim - 1);
  const int ih = input.size(odim - 2);
  const int iw = input.size(odim - 1);

  // The kernel is instantiated for float and uint8_t inputs. The template
  // above would also work for double, but no other dtypes are wired up here.
#define typed_avg_pool2d(btype, ctype, quantized, dtype) \
  case ScalarType::dtype: {                              \
    avg_pool2d_nchw<btype, ctype, quantized>(            \
        input.const_data_ptr<btype>(),                   \
        in_zero_point,                                   \
        out.mutable_data_ptr<btype>(),                   \
        kernel_size,                                     \
        stride,                                          \
        padding,                                         \
        count_include_pad,                               \
        divisor,                                         \
        on,                                              \
        ih,                                              \
        iw,                                              \
        oh,                                              \
        ow);                                             \
    break;                                               \
  }
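// For reference, the Float case below expands to
// avg_pool2d_nchw<float, float, false>(input.const_data_ptr<float>(), ...)
// with the shape arguments computed above.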

  ScalarType dtype = input.scalar_type();
  switch (dtype) {
    typed_avg_pool2d(float, float, false, Float);
    typed_avg_pool2d(uint8_t, int32_t, true, Byte);
    default:
      ET_DCHECK_MSG(
          false,
          "avg_pool2d not implemented for dtype %s",
          torch::executor::toString(dtype));
  }
#undef typed_avg_pool2d

  return out;
}

} // namespace native
} // namespace generic
} // namespace impl