@@ -116,7 +116,7 @@ static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
-template <bool is_scatter_like, typename scalar_t>
+template <bool is_scatter_like, typename scalar_t, typename index_t>
 struct _cuda_scatter_gather_internal_kernel {
   template <typename func_t>
   void operator() (
@@ -128,7 +128,7 @@ struct _cuda_scatter_gather_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t>()(
+        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t, index_t>()(
           sub_iter, index_size, index_stride, numel, f
         );
       }
@@ -151,15 +151,15 @@ struct _cuda_scatter_gather_internal_kernel {
         auto inp_stride_bytes = index_stride * element_size;
         auto out_stride_bytes = iter.strides(0)[1];
         if (iter.numel() == 0) return;
-        at::native::vectorized_gather_kernel_launch<alignment>(self_ptr, src_ptr, (int64_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
+        at::native::vectorized_gather_kernel_launch<alignment, index_t>(self_ptr, src_ptr, (index_t*)index_ptr, num_ind, slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes);
         return;
       }
     }
     auto offset_calc = make_offset_calculator<3>(iter);
     auto loop = [=]C10_DEVICE (int i) {
       auto offsets = offset_calc.get(i);
 
-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "scatter gather kernel index out of bounds");
@@ -229,9 +229,11 @@ struct cuda_scatter_gather_base_kernel {
         using dtype = typename std::conditional<cast_to_opaque,
           OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
 
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
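
The new inner dispatch is what lets `index_t` vary per call: `AT_DISPATCH_INDEX_TYPES` switches on the index tensor's dtype (`kInt` or `kLong`) and binds `index_t` to the matching C++ type inside the lambda. A standalone illustration of the macro (hypothetical `first_index_as_int64` function, not part of the patch; assumes a non-empty CPU index tensor):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Reads the first element of an int32 or int64 index tensor and returns it as
// int64_t. Inside the lambda, index_t is int32_t for kInt and int64_t for kLong.
int64_t first_index_as_int64(const at::Tensor& index_cpu) {
  TORCH_CHECK(index_cpu.numel() > 0, "index tensor must be non-empty");
  int64_t out = -1;
  AT_DISPATCH_INDEX_TYPES(index_cpu.scalar_type(), "first_index_as_int64", [&] {
    out = static_cast<int64_t>(index_cpu.data_ptr<index_t>()[0]);
  });
  return out;
}
```
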
@@ -279,19 +281,40 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
 
-
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
-      iter.dtype(),
-      "cuda_scatter_gather_base_kernel_func", [&] {
+    if (self.is_quantized()) {
+      TORCH_CHECK(
+        self.qscheme() == kPerTensorAffine,
+        "Only per_tensor quantized quantized tensors are supported by gather.")
+      AT_DISPATCH_QINT_TYPES(iter.dtype(), "gather_quant_cuda", [&] {
         using dtype = typename std::conditional<cast_to_opaque,
-          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
-
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
-      }
-    );
+          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
+      });
+    } else {
+      AT_DISPATCH_V2(
+        iter.dtype(),
+        "gather_cuda",
+        AT_WRAP([&] {
+          using dtype = typename std::conditional<cast_to_opaque,
+            OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
+          AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+            _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+              iter, index_size, index_stride, self.numel(), f
+            );
+          });
+        }),
+        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+        AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
+        AT_EXPAND(AT_FLOAT8_TYPES),
+        kComplexHalf,
+        kHalf,
+        kBool,
+        kBFloat16);
+    }
   }
 
   template <typename func_t>
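
In the non-quantized branch the value-type dispatch (`AT_DISPATCH_V2` with an explicit dtype list) and the index-type dispatch nest, so the kernel body is instantiated once per `(scalar_t, index_t)` pair. A standalone sketch of that nesting (hypothetical `check_dispatch_pair` function, not part of the patch; a reduced dtype list is used for brevity):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>

using namespace at;

// Inside the nested lambdas both types are concrete: scalar_t is the element
// type of `values`, and index_t is int32_t or int64_t depending on `index`.
void check_dispatch_pair(const Tensor& values, const Tensor& index) {
  AT_DISPATCH_V2(values.scalar_type(), "check_dispatch_pair", AT_WRAP([&] {
    AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "check_dispatch_pair_index", [&] {
      TORCH_INTERNAL_ASSERT(values.element_size() == static_cast<int64_t>(sizeof(scalar_t)));
      TORCH_INTERNAL_ASSERT(index.element_size() == static_cast<int64_t>(sizeof(index_t)));
    });
  }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kHalf, kBFloat16, kBool);
}
```
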
@@ -338,23 +361,24 @@ struct cuda_scatter_gather_base_kernel {
     auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
     auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;
 
-
     AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16,
       iter.dtype(),
       "cuda_scatter_gather_base_kernel_func", [&] {
         using dtype = typename std::conditional<cast_to_opaque,
           OpaqueType<sizeof(scalar_t)>, scalar_t>::type;
 
-        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
-          iter, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_gather_base_kernel_func", [&] () {
+          _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype, index_t>()(
+            iter, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
 }; // struct cuda_scatter_gather_base_kernel
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 struct _cuda_scatter_fill_internal_kernel {
   template <typename func_t>
   void operator() (
@@ -367,7 +391,7 @@ struct _cuda_scatter_fill_internal_kernel {
   ) {
     if (!iter.can_use_32bit_indexing()) {
       for (auto& sub_iter : iter.with_32bit_indexing()) {
-        _cuda_scatter_fill_internal_kernel<scalar_t>()(
+        _cuda_scatter_fill_internal_kernel<scalar_t, index_t>()(
           sub_iter, src_val, index_size, index_stride, numel, f
         );
       }
@@ -381,7 +405,7 @@ struct _cuda_scatter_fill_internal_kernel {
     auto loop = [=]C10_DEVICE (int i) {
       auto offsets = offset_calc.get(i);
 
-      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[1]);
+      int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
       CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
                          && "index out of bounds"
       );
@@ -437,9 +461,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to<scalar_t>();
         auto src_val = *(dtype*)&src_scalar_val;
 
-        _cuda_scatter_fill_internal_kernel<dtype>()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_func", [&] () {
+          _cuda_scatter_fill_internal_kernel<dtype, index_t>()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }
@@ -480,9 +506,11 @@ struct cuda_scatter_fill_base_kernel {
         auto src_scalar_val = src.to<scalar_t>();
         auto src_val = *(dtype*)&src_scalar_val;
 
-        _cuda_scatter_fill_internal_kernel<dtype>()(
-          iter, src_val, index_size, index_stride, self.numel(), f
-        );
+        AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "cuda_scatter_fill_base_kernel_reduce_multiply", [&] () {
+          _cuda_scatter_fill_internal_kernel<dtype, index_t>()(
+            iter, src_val, index_size, index_stride, self.numel(), f
+          );
+        });
       }
     );
   }