
Commit f0c9b33

isuruf authored and pytorchmergebot committed
Support more dtypes for input, indices in gather (pytorch#151822)
Pull Request resolved: pytorch#151822 Approved by: https://github.com/ngimel
1 parent 4c8dee7 commit f0c9b33
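
The user-visible effect of this change is that gather (and the shared scatter/gather dtype check) accepts int32 index tensors in addition to int64. A minimal sketch of the intended behavior; the shapes and values below are illustrative and not taken from the PR's tests:

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
idx32 = torch.tensor([[0, 1], [2, 3], [1, 0]], dtype=torch.int32)

# Previously this raised "gather(): Expected dtype int64 for index";
# with this commit int32 indices are accepted as well.
out = torch.gather(x, 1, idx32)  # picks x[i, idx32[i, j]] for each position
print(out)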

10 files changed: 97 additions and 68 deletions


aten/src/ATen/native/ScatterGatherChecks.h

Lines changed: 2 additions & 2 deletions
@@ -19,8 +19,8 @@ inline void scatter_gather_dtype_check(
 ) {
   if (index.numel() != 0) {
     TORCH_CHECK(
-      index.scalar_type() == at::ScalarType::Long,
-      method_name, "(): Expected dtype int64 for index"
+      index.scalar_type() == at::ScalarType::Long || index.scalar_type() == at::ScalarType::Int,
+      method_name, "(): Expected dtype int32/int64 for index"
     );
   }

aten/src/ATen/native/TensorAdvancedIndexing.cpp

Lines changed: 3 additions & 2 deletions
@@ -175,9 +175,10 @@ TORCH_META_FUNC(gather)
   auto is_index_empty = index.numel() == 0;
   if (!is_index_empty) {
     TORCH_CHECK(
-        index.scalar_type() == at::ScalarType::Long,
+        index.scalar_type() == ScalarType::Long ||
+            index.scalar_type() == ScalarType::Int,
         "gather",
-        "(): Expected dtype int64 for index");
+        "(): Expected dtype int32/int64 for index");
   }
   if (is_index_empty)
     return;

aten/src/ATen/native/cpu/ScatterGatherKernel.cpp

Lines changed: 12 additions & 6 deletions
@@ -167,10 +167,11 @@ template <bool is_scatter_like = true>
 struct cpu_scatter_gather_base_kernel {
   template <typename func_t>
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Scalar& value,
+    const Tensor& _index, const Scalar& value,
     const std::string& method_name, func_t& kernel_func) {

     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -263,10 +264,11 @@ struct cpu_scatter_gather_base_kernel {

   template <typename func_t>
   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, func_t& kernel_func) {

     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -358,10 +360,11 @@ struct cpu_scatter_gather_base_kernel {
   }

   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMean& kernel_func) {

     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -453,9 +456,10 @@ struct cpu_scatter_gather_base_kernel {
   }

   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMaximum& kernel_func) {
     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -547,10 +551,11 @@ struct cpu_scatter_gather_base_kernel {
   }

   void operator()(const Tensor& self, int64_t dim,
-    const Tensor& index, const Tensor& src,
+    const Tensor& _index, const Tensor& src,
     const std::string& method_name, ReduceMinimum& kernel_func) {

     Tensor buffer;
+    Tensor index = _index.to(ScalarType::Long);
     bool need_acc = isReducedFloatingType(self.scalar_type());
     create_acc_buffer(buffer, self, need_acc);

@@ -810,7 +815,8 @@ void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index,
 }

 template <typename scalar_t>
-void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& index, const Tensor& self) {
+void cpu_gather_expanded_index_kernel(const Tensor& result, const Tensor& _index, const Tensor& self) {
+  Tensor index = _index.to(ScalarType::Long);
   const int64_t* index_data = index.const_data_ptr<int64_t>();
   scalar_t* result_data = result.data_ptr<scalar_t>();
   const scalar_t* self_data = self.const_data_ptr<scalar_t>();
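
Note that the CPU kernels are not templated over the index dtype; each overload of cpu_scatter_gather_base_kernel (and cpu_gather_expanded_index_kernel) up-casts the incoming index to int64 before running. Functionally this should match pre-casting the index in Python, as in this rough sketch (shapes illustrative, not from the PR):

import torch

x = torch.randn(4, 5)
idx32 = torch.randint(0, 5, (4, 3), dtype=torch.int32)

# On CPU the kernel converts int32 indices to int64 internally, so these two
# calls are expected to produce identical results.
assert torch.equal(torch.gather(x, 1, idx32), torch.gather(x, 1, idx32.long()))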

aten/src/ATen/native/cuda/IndexKernel.cu

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
       auto inp_stride_bytes = index_stride[0];
       auto out_stride_bytes = iter.strides(0)[1];
       if (iter.numel() == 0) return;
-      at::native::vectorized_gather_kernel_launch<alignment>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
+      at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
                                                              slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
       return;
     }

aten/src/ATen/native/cuda/IndexKernelUtils.cu

Lines changed: 8 additions & 6 deletions
@@ -7,8 +7,8 @@
 #include <ATen/ceil_div.h>

 namespace at::native {
-template <int Alignment>
-__global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
+template <int Alignment, typename index_t>
+__global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx, int num_ind, int64_t slice_size, int64_t ind_dim_size, int64_t inp_stride, int64_t out_stride, bool allow_neg_indices) {
   int64_t ind = idx[blockIdx.x];
   if (allow_neg_indices) {
     ind = (ind < 0) ? ind + ind_dim_size : ind;

@@ -22,8 +22,8 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, int64_t * idx,



-template <int64_t Alignment>
-void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+template <int64_t Alignment, typename index_t>
+void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
                                      int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices){

   constexpr int64_t max_num_threads=256;

@@ -32,13 +32,15 @@ void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int
                                      static_cast<int64_t>(C10_WARP_SIZE));
   dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
   auto block = std::min(max_num_threads, num_threads);
-  vectorized_gather_kernel<Alignment><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
+  vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
                                                                                             ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

 // explicit template instantiation
-template void vectorized_gather_kernel_launch<16>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
+template void vectorized_gather_kernel_launch<16, int64_t>(char * out, char * inp, int64_t * idx, int num_ind, int64_t slice_size_in_bytes,
+                                                           int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);
+template void vectorized_gather_kernel_launch<16, int32_t>(char * out, char * inp, int32_t * idx, int num_ind, int64_t slice_size_in_bytes,
                                                            int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes, bool allow_neg_indices);

 }

aten/src/ATen/native/cuda/IndexKernelUtils.h

Lines changed: 2 additions & 2 deletions
@@ -26,8 +26,8 @@ inline bool fast_gather_kernel_eligible(const TensorIterator& iter, char * const
          get_alignment(static_cast<size_t>(iter.strides(0)[1])) == alignment;
 }

-template <int64_t Alignment>
-void vectorized_gather_kernel_launch(char * out, char * inp, int64_t * idx, int num_ind,
+template <int64_t Alignment, typename index_t>
+void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int num_ind,
                                      int64_t slice_size_in_bytes, int64_t ind_dim_size, int64_t inp_stride_bytes, int64_t out_stride_bytes,
                                      bool allow_neg_indices=false);

aten/src/ATen/native/cuda/ScatterGatherKernel.cu

Lines changed: 60 additions & 32 deletions
@@ -116,7 +116,7 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

-template <bool is_scatter_like, typename scalar_t>
+template <bool is_scatter_like, typename scalar_t, typename index_t>
 struct _cuda_scatter_gather_internal_kernel {
   template <typename func_t>
   void operator() (

@@ -128,7 +128,7 @@ struct _cuda_scatter_gather_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t>()(
+        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t, index_t>()(
          sub_iter, index_size, index_stride, numel, f
        );
      }

@@ -151,15 +151,15 @@ struct _cuda_scatter_gather_internal_kernel {
         auto inp_stride_bytes = index_stride * element_size;
         auto out_stride_bytes = iter.strides(0)[1];
         if (iter.numel() == 0) return;
-        at::native::vectorized_gather_kernel_launch<alignment>(self_ptr, src_ptr, (int64_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
+        at::native::vectorized_gather_kernel_launch<alignment, index_t>(self_ptr, src_ptr, (index_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
         return;
       }
     }
     auto offset_calc = make_offset_calculator<3>(iter);
     auto loop = [=]C10_DEVICE(int i) {
       auto offsets = offset_calc.get(i);

-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "scatter gather kernel index out of bounds");

@@ -229,9 +229,11 @@ struct cuda_scatter_gather_base_kernel {
         using dtype = typename std::conditional<cast_to_opaque,
           OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }

@@ -279,19 +281,40 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;

-
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
-      iter.dtype(),
-      "cuda_scatter_gather_base_kernel_func", [&] {
+    if (self.is_quantized()) {
+      TORCH_CHECK(
+        self.qscheme() == kPerTensorAffine,
+        "Only per_tensor quantized quantized tensors are supported by gather.")
+      AT_DISPATCH_QINT_TYPES(iter.dtype(), "gather_quant_cuda", [&] {
         using dtype = typename std::conditional<cast_to_opaque,
-          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
-
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
-      }
-    );
+          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
+      });
+    } else {
+      AT_DISPATCH_V2(
+        iter.dtype(),
+        "gather_cuda",
+        AT_WRAP([&] {
+          using dtype = typename std::conditional<cast_to_opaque,
+            OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
+          AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+            _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+              iter, index_size, index_stride, self.numel(), f
+            );
+          });
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+        AT_EXPAND(AT_FLOAT8_TYPES),
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
+    }
   }

   template <typename func_t>

@@ -338,23 +361,24 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;

-
     AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16,
       iter.dtype(),
       "cuda_scatter_gather_base_kernel_func", [&] {
         using dtype = typename std::conditional<cast_to_opaque,
           OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
 }; // struct cuda_scatter_gather_base_kernel

-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 struct _cuda_scatter_fill_internal_kernel {
   template <typename func_t>
   void operator()(

@@ -367,7 +391,7 @@ struct _cuda_scatter_fill_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_fill_internal_kernel<scalar_t>()(
+        _cuda_scatter_fill_internal_kernel<scalar_t, index_t>()(
          sub_iter, src_val, index_size, index_stride, numel, f
        );
      }

@@ -381,7 +405,7 @@ struct _cuda_scatter_fill_internal_kernel {
     auto loop = [=]C10_DEVICE(int i) {
       auto offsets = offset_calc.get(i);

-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[1]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "index out of bounds"
       );

@@ -437,9 +461,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to<scalar_t>();
         auto src_val = *(dtype*)&src_scalar_val;

-        _cuda_scatter_fill_internal_kernel<dtype>()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_func", [&] () {
+          _cuda_scatter_fill_internal_kernel<dtype, index_t>()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }

@@ -480,9 +506,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to<scalar_t>();
         auto src_val = *(dtype*)&src_scalar_val;

-        _cuda_scatter_fill_internal_kernel<dtype>()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_reduce_multiply", [&] () {
+          _cuda_scatter_fill_internal_kernel<dtype, index_t>()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
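
On CUDA, by contrast, the gather, scatter, and scatter-fill kernels are templated on index_t and dispatch on the index dtype via AT_DISPATCH_INDEX_TYPES instead of hard-coding int64_t. Since the shared dtype check is relaxed as well, scatter-style ops with int32 indices should also be reachable, though gather is the primary target of the PR. A hedged sketch with illustrative shapes and values:

import torch

if torch.cuda.is_available():
    src = torch.ones(2, 5, device="cuda")
    idx32 = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dtype=torch.int32, device="cuda")
    out = torch.zeros(3, 5, device="cuda")
    # int32 index tensor; the relaxed check plus the index_t dispatch make this
    # path usable without casting to int64 first
    out.scatter_(0, idx32, src)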

torch/_inductor/lowering.py

Lines changed: 0 additions & 1 deletion
@@ -3335,7 +3335,6 @@ def gather(x, dim, index, sparse_grad=False):
         # Empty index case. Return an empty array with the same shape
         return new_empty(x, index.get_size())

-    assert index.get_dtype() == torch.int64
     size = x.get_size()
     offset = len(size) == 0
     dim = _validate_dim(x, dim, offset)
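
With the int64 assert removed from the Inductor gather lowering, a compiled gather should no longer reject int32 indices. A sketch of the kind of call this enables (function and tensor names are illustrative, not from the PR):

import torch

def take_rows(x, idx):
    return torch.gather(x, 0, idx)

compiled = torch.compile(take_rows)
x = torch.randn(8, 4)
idx32 = torch.randint(0, 8, (3, 4), dtype=torch.int32)
# Previously the lowering asserted index.get_dtype() == torch.int64.
out = compiled(x, idx32)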

torch/_meta_registrations.py

Lines changed: 4 additions & 4 deletions
@@ -5420,8 +5420,8 @@ def meta_gather(self, dim, index, sparse_grad=False):
     is_index_empty = guard_size_oblivious(index.numel() == 0)
     if not is_index_empty:
         torch._check(
-            index.dtype == torch.long,
-            lambda: f"gather(): Expected dtype int64 for index, but got {index.dtype}",
+            index.dtype == torch.long or index.dtype == torch.int,
+            lambda: f"gather(): Expected dtype int32/int64 for index, but got {index.dtype}",
         )
     gather_shape_check(self, wrapped_dim, index)
     return self.new_empty(index.shape)

@@ -5460,8 +5460,8 @@ def scatter_gather_dtype_check(method_name, self, index, src_opt=None):

     if guard_size_oblivious(index.numel() != 0):
         torch._check(
-            index.dtype == torch.long,
-            lambda: f"{method_name}(): Expected dtype int64 for index",
+            index.dtype == torch.long or index.dtype == torch.int,
+            lambda: f"{method_name}(): Expected dtype int32/int64 for index",
         )

     if src_opt is not None:
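
The relaxed meta registrations mean shape inference under FakeTensor/meta tensors also tolerates int32 indices. A small sketch, assuming a meta-device call is representative of the tracing path:

import torch

x = torch.empty(3, 4, device="meta")
idx32 = torch.zeros(3, 2, dtype=torch.int32, device="meta")
out = torch.gather(x, 1, idx32)  # meta kernel now accepts int32 indices
print(out.shape)  # torch.Size([3, 2])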
