
Commit 6962b97

Merge remote-tracking branch 'upstream/main' into port_event_less
2 parents: f1cc530 + f8b1ee9

File tree

9 files changed (+398 / -31 lines)


src/ATen/native/xpu/sycl/Loops.h

Lines changed: 5 additions & 3 deletions
@@ -620,7 +620,7 @@ void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) {
   gpu_kernel_impl_nocast(iter, f);
 }
 
-template <typename func_t>
+template <typename func_t, bool enable_broadcast_vec = true>
 void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {
   for (int arg = 0; arg < iter.ntensors(); arg++) {
     TORCH_INTERNAL_ASSERT(
@@ -637,12 +637,14 @@ void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {
 
   if (!iter.can_use_32bit_indexing()) {
     for (auto& sub_iter : iter.with_32bit_indexing()) {
-      gpu_kernel(sub_iter, f);
+      // Broadcasting vectorization is disabled for sub-iterators to prevent
+      // potential output offset calculation issues.
+      gpu_kernel<func_t, false>(sub_iter, f);
     }
     return;
   }
 
-  gpu_kernel_impl(iter, f);
+  gpu_kernel_impl<func_t, enable_broadcast_vec>(iter, f);
 }
 
 template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
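Context only, not part of the commit: a minimal sketch of how a caller can pick the new enable_broadcast_vec template parameter. AddScalarFunctor and launch_add_scalar_kernel are made-up names, and the sketch assumes it is compiled inside the same at::native::xpu namespace that Loops.h uses.

#include <ATen/native/xpu/sycl/Loops.h>

namespace at::native::xpu {

// Hypothetical elementwise functor for illustration.
struct AddScalarFunctor {
  float alpha;
  float operator()(float x) const {
    return x + alpha;
  }
};

// Hypothetical launcher: the default instantiation keeps broadcast
// vectorization enabled (enable_broadcast_vec == true).
inline void launch_add_scalar_kernel(TensorIteratorBase& iter, float alpha) {
  gpu_kernel(iter, AddScalarFunctor{alpha});
  // Explicit opt-out, mirroring the sub-iterator path in the diff above:
  // gpu_kernel<AddScalarFunctor, false>(iter, AddScalarFunctor{alpha});
}

} // namespace at::native::xpu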

src/ATen/native/xpu/sycl/LossNLLKernel.cpp

Lines changed: 20 additions & 15 deletions
@@ -1,3 +1,4 @@
+#include <ATen/AccumulateType.h>
 #include <ATen/Functions.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/core/Reduction.h>
@@ -126,7 +127,7 @@ struct NllLossForwardReduce1DKernelFunctor {
   int64_t reduction;
 };
 
-template <typename scalar_t, typename index_t>
+template <typename scalar_t, typename index_t, typename accscalar_t>
 struct NllLossForwardReduce2DKernelFunctor
     : public __SYCL_KER_CONFIG_CONVENTION__ {
   void operator()(sycl::nd_item<1> item_id) const {
@@ -136,17 +137,18 @@ struct NllLossForwardReduce2DKernelFunctor
     auto total_weight_ptr = total_weight_data;
     auto output_ptr = output_data;
     int64_t local_id = item_id.get_local_id(0);
-    local_output_acc[local_id] = 0.0;
-    local_total_weight_acc[local_id] = 0.0;
+    local_output_acc[local_id] = accscalar_t(0);
+    local_total_weight_acc[local_id] = accscalar_t(0);
     for (int i = local_id; i < batch_size; i += local_size) {
       int cur_target = target_ptr[i];
       if (cur_target != ignore_index) {
         scalar_t cur_weight =
             has_weight ? weight_ptr[cur_target] : static_cast<scalar_t>(1.0f);
-        local_total_weight_acc[local_id] += cur_weight;
+        local_total_weight_acc[local_id] +=
+            static_cast<accscalar_t>(cur_weight);
         local_output_acc[local_id] -=
-            static_cast<scalar_t>(input_ptr[i * n_target + cur_target]) *
-            static_cast<scalar_t>(cur_weight);
+            static_cast<accscalar_t>(input_ptr[i * n_target + cur_target]) *
+            static_cast<accscalar_t>(cur_weight);
       }
     }
 
@@ -161,11 +163,13 @@ struct NllLossForwardReduce2DKernelFunctor
     }
     item_id.barrier(sycl_global_and_local_fence);
 
-    output_ptr[0] = local_output_acc[0];
-    total_weight_ptr[0] = local_total_weight_acc[0];
     if (reduction == at::Reduction::Mean) {
-      output_ptr[0] /= total_weight_ptr[0];
+      output_ptr[0] = static_cast<scalar_t>(
+          local_output_acc[0] / local_total_weight_acc[0]);
+    } else {
+      output_ptr[0] = static_cast<scalar_t>(local_output_acc[0]);
     }
+    total_weight_ptr[0] = static_cast<scalar_t>(local_total_weight_acc[0]);
   }
   NllLossForwardReduce2DKernelFunctor(
       scalar_t* input_data_,
@@ -192,8 +196,8 @@ struct NllLossForwardReduce2DKernelFunctor
         reduction(reduction_) {}
 
   void sycl_ker_config_convention(sycl::handler& cgh) {
-    local_output_acc = sycl_local_acc_t<scalar_t>(local_size, cgh);
-    local_total_weight_acc = sycl_local_acc_t<scalar_t>(local_size, cgh);
+    local_output_acc = sycl_local_acc_t<accscalar_t>(local_size, cgh);
+    local_total_weight_acc = sycl_local_acc_t<accscalar_t>(local_size, cgh);
   }
 
 private:
@@ -207,8 +211,8 @@ struct NllLossForwardReduce2DKernelFunctor
   int64_t local_size;
   int64_t ignore_index;
   int n_target;
-  sycl_local_acc_t<scalar_t> local_output_acc;
-  sycl_local_acc_t<scalar_t> local_total_weight_acc;
+  sycl_local_acc_t<accscalar_t> local_output_acc;
+  sycl_local_acc_t<accscalar_t> local_total_weight_acc;
   int64_t reduction;
 };
 
@@ -309,8 +313,9 @@ void nll_loss_forward_template(
 
     sycl_kernel_submit(sycl::range<1>(local_size), queue, kfn);
   } else if (input_cont.dim() == 2) {
+    using accscalar_t = at::acc_type<scalar_t, true>;
     using NllLossForwardReduce2DKernel =
-        NllLossForwardReduce2DKernelFunctor<scalar_t, index_t>;
+        NllLossForwardReduce2DKernelFunctor<scalar_t, index_t, accscalar_t>;
 
     int64_t batch_size = input.size(0);
     int n_target = input.size(1);
@@ -322,7 +327,7 @@ void nll_loss_forward_template(
     auto target_data = _target_data;
    auto total_weight_data = _total_weight_data;
    auto output_data = _output_data;
-    NllLossForwardReduce2DKernelFunctor<scalar_t, index_t> kfn(
+    NllLossForwardReduce2DKernelFunctor<scalar_t, index_t, accscalar_t> kfn(
        input_data,
        target_data,
        weight_data,
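For reference, a standalone sketch (not part of the commit) of what at::acc_type from ATen/AccumulateType.h resolves to. This is why the 2D reduction now keeps its local-memory partial sums in accscalar_t: half and bfloat16 inputs accumulate in float, and the result is cast back to scalar_t only when written out.

#include <ATen/AccumulateType.h>
#include <type_traits>

// Reduced-precision inputs get a wider accumulator on device code paths
// (second template argument /*is_cuda=*/true); wider types are unchanged.
static_assert(
    std::is_same_v<at::acc_type<c10::Half, /*is_cuda=*/true>, float>,
    "half accumulates in float");
static_assert(
    std::is_same_v<at::acc_type<c10::BFloat16, /*is_cuda=*/true>, float>,
    "bfloat16 accumulates in float");
static_assert(
    std::is_same_v<at::acc_type<double, /*is_cuda=*/true>, double>,
    "double keeps accumulating in double");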

src/xccl/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -2,10 +2,13 @@
 
 file(GLOB xccl_h "*.hpp")
 file(GLOB xccl_cpp "*.cpp")
+list(REMOVE_ITEM xccl_cpp "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp")
 
 list(APPEND ATen_XPU_XCCL_SRCS ${xccl_cpp})
+list(APPEND ATen_XPU_SYCL_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/NanCheck_XPU.cpp")
 
 set(ATen_XPU_XCCL_SRCS ${ATen_XPU_XCCL_SRCS} PARENT_SCOPE)
+set(ATen_XPU_SYCL_SRCS ${ATen_XPU_SYCL_SRCS} PARENT_SCOPE)
 
 # Why copy the header file to the build directory?
 # We want register XCCL backend to PyTorch c10d in torch/csrc/distributed/c10d/init.cpp#L27-L29.

src/xccl/NanCheck_XPU.cpp

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
+#include <ATen/Dispatch.h>
+#include <ATen/NumericUtils.h>
+#include <ATen/native/xpu/sycl/MemoryAccessUtils.h>
+#include <ATen/xpu/XPUContext.h>
+#include <comm/SYCLContext.h>
+#include <stdint.h>
+#include <torch/torch.h>
+#include <xccl/NanCheck_XPU.hpp>
+#include <algorithm>
+
+namespace c10d {
+
+using BytePack = at::native::memory::aligned_vector<uint64_t, 2>;
+
+template <typename T, int EltPerPack>
+struct CheckBytePack {
+  static void check(BytePack* tmp) {
+    T* data = (T*)tmp;
+#pragma unroll 8
+    for (int i = 0; i < EltPerPack; i++) {
+      if (at::_isnan(data[i]))
+        assert(0);
+    }
+  }
+};
+
+template <typename T>
+struct CheckBytePack<T, /*EltPerPack*/ 2> {
+  static void check(BytePack* tmp) {
+    T* data = (T*)tmp;
+    if (at::_isnan(data[0]) || at::_isnan(data[1]))
+      assert(0);
+  }
+};
+
+template <typename T>
+struct CheckBytePack<T, /*EltPerPack*/ 4> {
+  static void check(BytePack* tmp) {
+    T* data = (T*)tmp;
+    if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) ||
+        at::_isnan(data[3]))
+      assert(0);
+  }
+};
+
+template <typename T>
+struct CheckBytePack<T, /*EltPerPack*/ 8> {
+  static void check(BytePack* tmp) {
+    T* data = (T*)tmp;
+    if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) ||
+        at::_isnan(data[3]) || at::_isnan(data[4]) || at::_isnan(data[5]) ||
+        at::_isnan(data[6]) || at::_isnan(data[7])) {
+      assert(0);
+    }
+  }
+};
+
+template <typename T>
+struct HasNanFP8x8 {
+  static bool check(uint64_t fp8x8) = delete;
+  /*
+  {
+    // `static_assert` in template definition requires c++23 onwards.
+    // But the error message still applies if you find yourself here.
+    static_assert(
+        false,
+        "You should never call this template definition because it is empty. You "
+        "can follow the example of Float8_e4m3fn below to implement the check for "
+        "your new datatype."
+    );
+  }
+  */
+};
+
+template <>
+struct HasNanFP8x8<c10::Float8_e4m3fn> {
+  static bool check(uint64_t fp8x8) {
+    auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
+    auto incremented = t + 0x0101010101010101ULL;
+    auto overflow = incremented & 0x8080808080808080ULL;
+    return overflow != 0;
+  }
+};
+
+template <>
+struct HasNanFP8x8<c10::Float8_e5m2> {
+  static bool check(uint64_t fp8x8) {
+    auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
+    auto incremented = t + 0x0303030303030303ULL;
+    auto overflow = incremented & 0x8080808080808080ULL;
+    return overflow != 0;
+  }
+};
+
+template <typename T>
+struct CheckBytePack<T, /*EltPerPack*/ 16> {
+  static void check(BytePack* tmp) {
+    if (HasNanFP8x8<T>::check(tmp->val[0]) ||
+        HasNanFP8x8<T>::check(tmp->val[1]))
+      assert(0);
+  }
+};
+
+#define UNROLL 8
+
+template <typename T>
+void checkChunk(BytePack* ptr, int nWorkers) {
+  BytePack tmp[UNROLL];
+
+#pragma unroll 8
+  for (int j = 0; j < UNROLL; j++) {
+    tmp[j] = ptr[nWorkers * j];
+  }
+  // Then check each BytePack in the tmp buffer
+#pragma unroll 8
+  for (int j = 0; j < UNROLL; j++) {
+    CheckBytePack<T, sizeof(BytePack) / sizeof(T)>::check(tmp + j);
+  }
+  // Note: we separate the check from the load for efficient loading
+}
+
+// Align address of `ptr` up, to the alignment of `T`
+#define ALIGN_UP(ptr, T) \
+  (((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T))
+
+template <typename T>
+struct checkForNaN {
+  void operator()(sycl::nd_item<1> item) const {
+    constexpr int EltPerPack = sizeof(BytePack) / sizeof(T);
+
+    size_t offset = item.get_global_id(0);
+
+    // Align input address up to BytePack in case it is not
+    T* ptrAlign = (T*)ALIGN_UP(data, BytePack);
+    size_t preProcElts =
+        std::min<size_t>(static_cast<size_t>(ptrAlign - data), size);
+
+    size_t size_left = size;
+
+    if (offset < preProcElts) {
+      if (at::_isnan(data[offset]))
+        assert(0);
+    }
+    size_left -= preProcElts;
+
+    BytePack* ptr = (BytePack*)ptrAlign;
+    size_t sizeInBP = size_left * sizeof(T) / sizeof(BytePack);
+    size_t loopSize = item.get_global_range(0) * UNROLL;
+
+    for (; offset + loopSize <= sizeInBP; offset += loopSize) {
+      checkChunk<T>(ptr + offset, item.get_global_range(0));
+    }
+
+    for (; offset < sizeInBP; offset += item.get_global_range(0)) {
+      BytePack tmp = ptr[offset];
+      CheckBytePack<T, EltPerPack>::check(&tmp);
+    }
+
+    if (item.get_local_id(0) < size_left % EltPerPack) {
+      T* tailPtr = (T*)(ptr + sizeInBP);
+      if (at::_isnan(tailPtr[item.get_local_id(0)]))
+        assert(0);
+    }
+  }
+  checkForNaN(T* data, size_t size) : data(data), size(size) {}
+
+ private:
+  T* data;
+  size_t size;
+};
+
+template <typename T>
+void checkfornan_impl_xpu(
+    const at::Tensor& tensor,
+    at::xpu::XPUStream& stream) {
+  // skip check for non float types
+  if (!torch::is_floating_point(tensor)) {
+    return;
+  }
+
+  int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize<checkForNaN<T>>();
+
+  const size_t numThreadsPerBlock =
+      std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());
+
+  if (!(numThreadsPerBlock > 0)) {
+    return;
+  }
+
+  int64_t numBlocks =
+      (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock;
+  auto global_range{numBlocks * numThreadsPerBlock};
+  auto local_range{numThreadsPerBlock};
+
+  using Kernel = checkForNaN<T>;
+  auto kfn = Kernel(tensor.data_ptr<T>(), tensor.numel());
+
+  sycl_kernel_submit(global_range, local_range, stream.queue(), kfn);
+}
+
+// CHECK if a Tensor contains NAN in any of its element
+void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream) {
+  AT_DISPATCH_FLOATING_TYPES_AND4(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      at::ScalarType::Float8_e4m3fn,
+      at::ScalarType::Float8_e5m2,
+      tensor.scalar_type(),
+      "checkForNaN_XPU",
+      [&]() { checkfornan_impl_xpu<scalar_t>(tensor, stream); });
+}
+
+} // namespace c10d
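The HasNanFP8x8 specializations use a byte-parallel trick: clear each byte's sign bit, then add a constant chosen so that only NaN encodings carry into bit 7 of their byte (0x01 per byte for e4m3fn, whose only NaN pattern is 0x7F; 0x03 per byte for e5m2, whose NaNs are 0x7D-0x7F while infinity 0x7C must not be flagged). A standalone host-side sketch of the single-byte e4m3fn case, for illustration only (not part of the commit):

#include <cassert>
#include <cstdint>

// Single-byte version of the e4m3fn check done 8 bytes at a time above.
bool byte_is_nan_e4m3fn(uint8_t b) {
  uint8_t t = b & 0x7F;            // clear the sign bit
  uint8_t incremented = t + 0x01;  // only 0x7F carries into bit 7
  return (incremented & 0x80) != 0;
}

int main() {
  assert(byte_is_nan_e4m3fn(0x7F));   // NaN
  assert(byte_is_nan_e4m3fn(0xFF));   // NaN with sign bit set
  assert(!byte_is_nan_e4m3fn(0x7E));  // largest finite e4m3fn value (448)
  assert(!byte_is_nan_e4m3fn(0x00));  // +0
  return 0;
}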

src/xccl/NanCheck_XPU.hpp

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#pragma once
+
+#ifdef USE_C10D_XCCL
+
+#include <ATen/ATen.h>
+#include <c10/xpu/XPUStream.h>
+
+namespace c10d {
+
+void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream);
+
+} // namespace c10d
+
+#endif // USE_C10D_XCCL
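A hypothetical call-site sketch (not part of this commit): validate_before_collective is a made-up helper showing how the new entry point could be used to screen a tensor before handing it to a collective on the same stream.

#ifdef USE_C10D_XCCL
#include <xccl/NanCheck_XPU.hpp>

// `input` is assumed to live on an XPU device and `stream` to be the stream
// the collective will later run on.
inline void validate_before_collective(
    const at::Tensor& input,
    at::xpu::XPUStream& stream) {
  // Launches the NaN-check kernel on `stream`; the kernel asserts if any
  // element of `input` is NaN.
  c10d::checkForNan(input, stream);
}
#endif // USE_C10D_XCCL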
