Skip to content

Commit 8944db4

Browse files
jgibson2John GibsonCopilot
authored
grid_sampler_2d_out Portable Kernel Implementation (#16051)
### Summary Adds grid_sampler_2d_out portable kernels. Fixes #11328 Release notes: ops & kernels ### Test plan I was unable to get the internal tests using `torch.et_test` to build. However, I did add a fairly comprehensive test suite in kernels/portable/test/op_grid_sampler_2d_test.py which compares the exported operation results to torch.nn.functional.grid_sample operations. Note that there are _some_ differences in how the implementations handle nan, inf, and -inf values; however, all the tests using real numbers pass. --------- Co-authored-by: John Gibson <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 18c1c5b commit 8944db4

File tree

13 files changed

+1998
-0
lines changed

13 files changed

+1998
-0
lines changed

kernels/portable/cpu/op_grid_sampler_2d.cpp

Lines changed: 480 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/grid_sampler_2d_util.h>
10+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
11+
12+
namespace torch {
13+
namespace executor {
14+
15+
// Validates the arguments of grid_sampler_2d and resizes `out` to
// [N, C, H_out, W_out].
//
// Checks performed, in order:
//   - `input` is 4D and uses the default dim order,
//   - `grid` is 4D and its last dimension is 2,
//   - `input` and `grid` agree on the batch dimension,
//   - `input`, `grid` and `out` share a dtype.
//
// Returns Error::Ok on success, or Error::InvalidArgument if any check (or the
// final resize of `out`) fails.
Error check_grid_sampler_2d_args_and_resize_out(
    const Tensor& input,
    const Tensor& grid,
    Tensor& out) {
  // The sampled image is laid out as (N, C, H, W).
  const auto input_rank = input.dim();
  ET_CHECK_OR_RETURN_ERROR(
      input_rank == 4,
      InvalidArgument,
      "Input must be 4D, got %zu dimensions",
      static_cast<size_t>(input_rank));

  const bool input_is_default_order = tensor_is_default_dim_order(input);
  ET_CHECK_OR_RETURN_ERROR(
      input_is_default_order, InvalidArgument, "Input must be in NCHW format");

  // The sampling grid is laid out as (N, H_out, W_out, 2).
  const auto grid_rank = grid.dim();
  ET_CHECK_OR_RETURN_ERROR(
      grid_rank == 4,
      InvalidArgument,
      "Grid must be 4D, got %zu dimensions",
      static_cast<size_t>(grid_rank));

  // Only read grid sizes once the rank has been validated above.
  const auto grid_coord_count = grid.size(3);
  ET_CHECK_OR_RETURN_ERROR(
      grid_coord_count == 2,
      InvalidArgument,
      "Grid last dimension must be 2, got %ld",
      static_cast<long>(grid_coord_count));

  // Every batch entry of the input needs a matching grid.
  const auto input_batch = input.size(0);
  const auto grid_batch = grid.size(0);
  ET_CHECK_OR_RETURN_ERROR(
      input_batch == grid_batch,
      InvalidArgument,
      "Input and grid batch sizes must match, got input=%ld, grid=%ld",
      static_cast<long>(input_batch),
      static_cast<long>(grid_batch));

  // All three tensors are required to share one dtype.
  ET_CHECK_OR_RETURN_ERROR(
      tensors_have_same_dtype(input, grid),
      InvalidArgument,
      "Input and grid must have same dtype");

  ET_CHECK_OR_RETURN_ERROR(
      tensors_have_same_dtype(input, out),
      InvalidArgument,
      "Input and output must have the same dtype");

  // Output shape is [N, C, H_out, W_out]: batch and channels come from the
  // input, the spatial extent comes from the grid.
  const exec_aten::SizesType target_sizes[4] = {
      static_cast<exec_aten::SizesType>(input.size(0)),
      static_cast<exec_aten::SizesType>(input.size(1)),
      static_cast<exec_aten::SizesType>(grid.size(1)),
      static_cast<exec_aten::SizesType>(grid.size(2))};

  const Error resize_status = resize_tensor(out, {target_sizes, 4});
  ET_CHECK_OR_RETURN_ERROR(
      resize_status == Error::Ok,
      InvalidArgument,
      "Failed to resize output tensor");

  return Error::Ok;
}
77+
78+
} // namespace executor
79+
} // namespace torch
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
18+
// Ported from aten/src/ATen/native/GridSampler.h
// The enumerators are mapped to the integer values (0, 1, 2), so they must
// stay in the SAME ORDER as the corresponding enums in GridSampler.h. The
// values are spelled out explicitly to make that mapping visible.
enum class GridSamplerInterpolation {
  Bilinear = 0,
  Nearest = 1,
  Bicubic = 2,
};
enum class GridSamplerPadding {
  Zeros = 0,
  Border = 1,
  Reflection = 2,
};
23+
24+
// Ported from aten/src/ATen/native/GridSampler.h
// Maps a normalized coordinate in [-1, +1] to a pixel index value, treating
// each pixel as the area between (idx - 0.5) and (idx + 0.5).
//
//   align_corners == true: -1 and +1 land on the centers of the corner pixels
//     -1 --> 0
//     +1 --> (size - 1)
//     scale_factor = (size - 1) / 2
//
//   align_corners == false: -1 and +1 land on the image edges
//     -1 --> -0.5
//     +1 --> (size - 1) + 0.5 == size - 0.5
//     scale_factor = size / 2
template <typename scalar_t>
inline scalar_t
grid_sampler_unnormalize(scalar_t coord, int64_t size, bool align_corners) {
  // Shift the coordinate from [-1, 1] to [0, 2] once; both branches need it.
  const auto shifted = coord + 1;
  if (!align_corners) {
    // [0, 2] --> [-0.5, size - 0.5]
    return (shifted * size - 1) / 2;
  }
  // [0, 2] --> [0, size - 1]
  return (shifted / 2) * (size - 1);
}
46+
47+
// Ported from aten/src/ATen/native/GridSampler.h
// Clamps a coordinate into the range [0, clip_limit - 1].
template <typename scalar_t>
inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
  const scalar_t upper_bound = static_cast<scalar_t>(clip_limit - 1);
  // Raise anything below zero first, then cap at the last valid index.
  const scalar_t raised = std::max(in, static_cast<scalar_t>(0));
  return std::min(upper_bound, raised);
}
55+
56+
// Ported from aten/src/ATen/native/GridSampler.h
// Reflects a coordinate until it falls between low and high (inclusive).
// The bounds are passed as twice their value so that half-integer bounds can
// be represented as ints.
//
// Returns 0 when the two bounds coincide (there is nothing to reflect across).
//
// Unlike the ATen original, the number of reflections is kept in floating
// point rather than being cast to `int`: converting a floating value that is
// out of `int` range (or NaN/inf) to `int` is undefined behavior, and in
// practice produced the wrong reflection parity for very large coordinates.
template <typename scalar_t>
inline scalar_t
reflect_coordinates(scalar_t in, int64_t twice_low, int64_t twice_high) {
  if (twice_low == twice_high) {
    return static_cast<scalar_t>(0);
  }
  const scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  const scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = std::fabs(in - min);
  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
  const scalar_t extra = std::fmod(in, span);
  // Number of whole spans crossed; integer-valued but kept as floating point
  // so that arbitrarily large magnitudes stay well defined.
  const scalar_t flips = std::floor(in / span);
  const bool odd_flips =
      std::fmod(flips, static_cast<scalar_t>(2)) != static_cast<scalar_t>(0);
  if (odd_flips) {
    // An odd number of reflections mirrors the remainder within the span.
    return span - extra + min;
  }
  return extra + min;
}
78+
79+
// Ported from aten/src/ATen/native/GridSampler.h
80+
// Computes the pixel source index value for a grid coordinate
81+
template <typename scalar_t>
82+
inline scalar_t grid_sampler_compute_source_index(
83+
scalar_t coord,
84+
int64_t size,
85+
GridSamplerPadding padding_mode,
86+
bool align_corners) {
87+
coord = grid_sampler_unnormalize(coord, size, align_corners);
88+
if (padding_mode == GridSamplerPadding::Border) {
89+
// clip coordinates to image borders
90+
coord = clip_coordinates(coord, size);
91+
} else if (padding_mode == GridSamplerPadding::Reflection) {
92+
// reflect coordinates by image borders
93+
if (align_corners) {
94+
coord = reflect_coordinates(coord, 0, 2 * (size - 1));
95+
} else {
96+
coord = reflect_coordinates(coord, -1, 2 * size - 1);
97+
}
98+
coord = clip_coordinates(coord, size);
99+
}
100+
return coord;
101+
}
102+
103+
// Ported from aten/src/ATen/native/GridSampler.h
// Returns true when 0 <= h < H and 0 <= w < W.
template <typename scalar_t>
inline bool within_bounds_2d(scalar_t h, scalar_t w, int64_t H, int64_t W) {
  const bool row_in_range = h >= 0 && h < H;
  const bool col_in_range = w >= 0 && w < W;
  return row_in_range && col_in_range;
}
109+
110+
// Ported from aten/src/ATen/native/UpSample.h
// Cubic convolution function 1 (for points within 1 unit of the point).
// Evaluates ((A + 2) * x - (A + 3)) * x^2 + 1.
template <typename scalar_t>
inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
  const scalar_t linear_term = (A + 2) * x - (A + 3);
  return linear_term * x * x + 1;
}
116+
117+
// Ported from aten/src/ATen/native/UpSample.h
// Cubic convolution function 2 (for points between 1 and 2 units from the
// point). Evaluates ((A * x - 5A) * x + 8A) * x - 4A.
template <typename scalar_t>
inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
  const scalar_t quadratic_term = (A * x - 5 * A) * x + 8 * A;
  return quadratic_term * x - 4 * A;
}
124+
125+
// Ported from aten/src/ATen/native/UpSample.h
126+
// Computes the 4 cubic interpolation coefficients for a given position t in [0,
127+
// 1]
128+
template <typename scalar_t>
129+
inline void get_cubic_upsample_coefficients(scalar_t coeffs[4], scalar_t t) {
130+
// Standard bicubic interpolation uses alpha = -0.75
131+
scalar_t A = static_cast<scalar_t>(-0.75);
132+
133+
scalar_t x1 = t;
134+
coeffs[0] = cubic_convolution2<scalar_t>(x1 + static_cast<scalar_t>(1.0), A);
135+
coeffs[1] = cubic_convolution1<scalar_t>(x1, A);
136+
137+
scalar_t x2 = static_cast<scalar_t>(1.0) - t;
138+
coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
139+
coeffs[3] = cubic_convolution2<scalar_t>(x2 + static_cast<scalar_t>(1.0), A);
140+
}
141+
142+
// Ported from aten/src/ATen/native/UpSample.h
143+
// Performs 1D cubic interpolation given 4 points and a position t in [0, 1]
144+
template <typename scalar_t>
145+
inline scalar_t
146+
cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t) {
147+
scalar_t coeffs[4];
148+
get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
149+
150+
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
151+
}
152+
153+
// Argument checking and output tensor resizing for grid_sampler_2d
154+
Error check_grid_sampler_2d_args_and_resize_out(
155+
const Tensor& input,
156+
const Tensor& grid,
157+
Tensor& out);
158+
159+
} // namespace executor
160+
} // namespace torch

kernels/portable/cpu/util/targets.bzl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def define_common_targets():
3636
"//executorch/kernels/portable/cpu/util:elementwise_util",
3737
"//executorch/kernels/portable/cpu/util:upsample_util",
3838
"//executorch/kernels/portable/cpu/util:vectorized_math",
39+
"//executorch/kernels/portable/cpu/util:grid_sampler_2d_util",
3940
],
4041
visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
4142
)
@@ -342,6 +343,16 @@ def define_common_targets():
342343
],
343344
)
344345

346+
runtime.cxx_library(
347+
name = "grid_sampler_2d_util",
348+
srcs = ["grid_sampler_2d_util.cpp"],
349+
exported_headers = ["grid_sampler_2d_util.h"],
350+
deps = [
351+
"//executorch/runtime/kernel:kernel_includes",
352+
],
353+
visibility = ["//executorch/kernels/portable/cpu/..."],
354+
)
355+
345356
# Utility functions that can be used by operators that perform reduction
346357
for aten_mode in get_aten_mode_options():
347358
suffix = "_aten" if aten_mode else ""

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,11 @@
427427
- arg_meta: null
428428
kernel_name: torch::executor::glu_out
429429

430+
- op: grid_sampler_2d.out
431+
kernels:
432+
- arg_meta: null
433+
kernel_name: torch::executor::grid_sampler_2d_out
434+
430435
- op: gt.Scalar_out
431436
kernels:
432437
- arg_meta: null

kernels/portable/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ runtime.cxx_library(
1919
],
2020
deps = [
2121
"//executorch/extension/aten_util:aten_bridge",
22+
"//executorch/kernels/portable/cpu:op_grid_sampler_2d",
2223
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d",
2324
"//executorch/kernels/portable/cpu:op_upsample_bilinear2d_aa",
2425
"//executorch/kernels/portable/cpu:op_upsample_nearest2d",

0 commit comments

Comments
 (0)