1+ /*
2+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3+ * All rights reserved.
4+ *
5+ * This source code is licensed under the BSD-style license found in the
6+ * LICENSE file in the root directory of this source tree.
7+ */
8+
9+ #pragma once
10+
11+ #include < executorch/runtime/core/exec_aten/exec_aten.h>
12+ #include < executorch/runtime/core/exec_aten/util/tensor_util.h>
13+ #include < executorch/runtime/kernel/kernel_includes.h>
14+
15+ namespace torch {
16+ namespace executor {
17+
18+ // Ported from aten/src/ATen/native/GridSampler.h
19+ // note that these need to be in the SAME ORDER as the enum in GridSampler.h
20+ // as they are mapped to integer values (0, 1, 2) in this order
21+ enum class GridSamplerInterpolation { Bilinear, Nearest, Bicubic };
22+ enum class GridSamplerPadding { Zeros, Border, Reflection };
23+
24+ // Ported from aten/src/ATen/native/GridSampler.h
25+ // Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
26+ // where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
27+ // if align_corners: -1 and +1 get sent to the centers of the corner pixels
28+ // -1 --> 0
29+ // +1 --> (size - 1)
30+ // scale_factor = (size - 1) / 2
31+ // if not align_corners: -1 and +1 get sent to the image edges
32+ // -1 --> -0.5
33+ // +1 --> (size - 1) + 0.5 == size - 0.5
34+ // scale_factor = size / 2
35+ template <typename scalar_t >
36+ inline scalar_t
37+ grid_sampler_unnormalize (scalar_t coord, int64_t size, bool align_corners) {
38+ if (align_corners) {
39+ // unnormalize coord from [-1, 1] to [0, size - 1]
40+ return ((coord + 1 ) / 2 ) * (size - 1 );
41+ } else {
42+ // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
43+ return ((coord + 1 ) * size - 1 ) / 2 ;
44+ }
45+ }
46+
47+ // Ported from aten/src/ATen/native/GridSampler.h
48+ // Clips coordinates to between 0 and clip_limit - 1
49+ template <typename scalar_t >
50+ inline scalar_t clip_coordinates (scalar_t in, int64_t clip_limit) {
51+ return std::min (
52+ static_cast <scalar_t >(clip_limit - 1 ),
53+ std::max (in, static_cast <scalar_t >(0 )));
54+ }
55+
56+ // Ported from aten/src/ATen/native/GridSampler.h
57+ // Reflects coordinates until they fall between low and high (inclusive).
58+ // The bounds are passed as twice their value so that half-integer values
59+ // can be represented as ints.
60+ template <typename scalar_t >
61+ inline scalar_t
62+ reflect_coordinates (scalar_t in, int64_t twice_low, int64_t twice_high) {
63+ if (twice_low == twice_high) {
64+ return static_cast <scalar_t >(0 );
65+ }
66+ scalar_t min = static_cast <scalar_t >(twice_low) / 2 ;
67+ scalar_t span = static_cast <scalar_t >(twice_high - twice_low) / 2 ;
68+ in = std::fabs (in - min);
69+ // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
70+ scalar_t extra = std::fmod (in, span);
71+ int flips = static_cast <int >(std::floor (in / span));
72+ if (flips % 2 == 0 ) {
73+ return extra + min;
74+ } else {
75+ return span - extra + min;
76+ }
77+ }
78+
79+ // Ported from aten/src/ATen/native/GridSampler.h
80+ // Computes the pixel source index value for a grid coordinate
81+ template <typename scalar_t >
82+ inline scalar_t grid_sampler_compute_source_index (
83+ scalar_t coord,
84+ int64_t size,
85+ GridSamplerPadding padding_mode,
86+ bool align_corners) {
87+ coord = grid_sampler_unnormalize (coord, size, align_corners);
88+ if (padding_mode == GridSamplerPadding::Border) {
89+ // clip coordinates to image borders
90+ coord = clip_coordinates (coord, size);
91+ } else if (padding_mode == GridSamplerPadding::Reflection) {
92+ // reflect coordinates by image borders
93+ if (align_corners) {
94+ coord = reflect_coordinates (coord, 0 , 2 * (size - 1 ));
95+ } else {
96+ coord = reflect_coordinates (coord, -1 , 2 * size - 1 );
97+ }
98+ coord = clip_coordinates (coord, size);
99+ }
100+ return coord;
101+ }
102+
103+ // Ported from aten/src/ATen/native/GridSampler.h
104+ // Check if coordinates are within bounds [0, limit-1]
105+ template <typename scalar_t >
106+ inline bool within_bounds_2d (scalar_t h, scalar_t w, int64_t H, int64_t W) {
107+ return h >= 0 && h < H && w >= 0 && w < W;
108+ }
109+
110+ // Ported from aten/src/ATen/native/UpSample.h
111+ // Cubic convolution function 1 (for points within 1 unit of the point)
112+ template <typename scalar_t >
113+ inline scalar_t cubic_convolution1 (scalar_t x, scalar_t A) {
114+ return ((A + 2 ) * x - (A + 3 )) * x * x + 1 ;
115+ }
116+
117+ // Ported from aten/src/ATen/native/UpSample.h
118+ // Cubic convolution function 2 (for points between 1 and 2 units from the
119+ // point)
120+ template <typename scalar_t >
121+ inline scalar_t cubic_convolution2 (scalar_t x, scalar_t A) {
122+ return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
123+ }
124+
125+ // Ported from aten/src/ATen/native/UpSample.h
126+ // Computes the 4 cubic interpolation coefficients for a given position t in [0,
127+ // 1]
128+ template <typename scalar_t >
129+ inline void get_cubic_upsample_coefficients (scalar_t coeffs[4 ], scalar_t t) {
130+ // Standard bicubic interpolation uses alpha = -0.75
131+ scalar_t A = static_cast <scalar_t >(-0.75 );
132+
133+ scalar_t x1 = t;
134+ coeffs[0 ] = cubic_convolution2<scalar_t >(x1 + static_cast <scalar_t >(1.0 ), A);
135+ coeffs[1 ] = cubic_convolution1<scalar_t >(x1, A);
136+
137+ scalar_t x2 = static_cast <scalar_t >(1.0 ) - t;
138+ coeffs[2 ] = cubic_convolution1<scalar_t >(x2, A);
139+ coeffs[3 ] = cubic_convolution2<scalar_t >(x2 + static_cast <scalar_t >(1.0 ), A);
140+ }
141+
142+ // Ported from aten/src/ATen/native/UpSample.h
143+ // Performs 1D cubic interpolation given 4 points and a position t in [0, 1]
144+ template <typename scalar_t >
145+ inline scalar_t
146+ cubic_interp1d (scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t) {
147+ scalar_t coeffs[4 ];
148+ get_cubic_upsample_coefficients<scalar_t >(coeffs, t);
149+
150+ return x0 * coeffs[0 ] + x1 * coeffs[1 ] + x2 * coeffs[2 ] + x3 * coeffs[3 ];
151+ }
152+
153+ // Argument checking and output tensor resizing for grid_sampler_2d
154+ Error check_grid_sampler_2d_args_and_resize_out (
155+ const Tensor& input,
156+ const Tensor& grid,
157+ Tensor& out);
158+
159+ } // namespace executor
160+ } // namespace torch
0 commit comments