#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/xpu/sycl/MemoryAccessUtils.h>
#include <ATen/xpu/XPUContext.h>
#include <comm/SYCLContext.h>
#include <stdint.h>
#include <torch/torch.h>
#include <xccl/NanCheck_XPU.hpp>
#include <algorithm>

namespace c10d {

using BytePack = at::native::memory::aligned_vector<uint64_t, 2>;
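// A BytePack is a 16-byte pack (two uint64_t lanes); data is loaded in
// BytePack units, i.e. sizeof(BytePack) / sizeof(T) elements of T per load.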

template <typename T, int EltPerPack>
struct CheckBytePack {
  static void check(BytePack* tmp) {
    T* data = (T*)tmp;
#pragma unroll 8
    for (int i = 0; i < EltPerPack; i++) {
      if (at::_isnan(data[i]))
        assert(0);
    }
  }
};

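// The specializations below fully unroll the common pack widths:
// EltPerPack == 2 (e.g. double), 4 (float), 8 (half / bfloat16), and
// 16 (the fp8 types).
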
template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 2> {
  static void check(BytePack* tmp) {
    T* data = (T*)tmp;
    if (at::_isnan(data[0]) || at::_isnan(data[1]))
      assert(0);
  }
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 4> {
  static void check(BytePack* tmp) {
    T* data = (T*)tmp;
    if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) ||
        at::_isnan(data[3]))
      assert(0);
  }
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 8> {
  static void check(BytePack* tmp) {
    T* data = (T*)tmp;
    if (at::_isnan(data[0]) || at::_isnan(data[1]) || at::_isnan(data[2]) ||
        at::_isnan(data[3]) || at::_isnan(data[4]) || at::_isnan(data[5]) ||
        at::_isnan(data[6]) || at::_isnan(data[7])) {
      assert(0);
    }
  }
};
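
// For fp8 types there are 16 elements per BytePack; rather than 16 scalar
// isnan tests, each uint64_t lane is checked with a byte-parallel bit trick
// (the HasNanFP8x8 specializations below).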
template <typename T>
struct HasNanFP8x8 {
  static bool check(uint64_t fp8x8) = delete;
  /*
  {
    // `static_assert` in a template definition requires C++23 onwards.
    // But the error message still applies if you find yourself here.
    static_assert(
        false,
        "You should never call this template definition because it is empty. "
        "You can follow the example of Float8_e4m3fn below to implement the "
        "check for your new datatype.");
  }
  */
};

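// Byte-parallel NaN test for Float8_e4m3fn, eight values per uint64_t.
// e4m3fn has exactly one NaN pattern per sign bit: S.1111.111 (0x7F / 0xFF).
// Clear the sign bit of every byte, then add 0x01 per byte: only 0x7F
// carries into bit 7, and the carry cannot spill into the next byte
// (0x7F + 0x01 = 0x80 at most). Any bit set under the 0x8080...80 mask
// therefore means some byte encoded NaN.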
template <>
struct HasNanFP8x8<c10::Float8_e4m3fn> {
  static bool check(uint64_t fp8x8) {
    auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
    auto incremented = t + 0x0101010101010101ULL;
    auto overflow = incremented & 0x8080808080808080ULL;
    return overflow != 0;
  }
};

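// Same trick for Float8_e5m2, whose NaN encodings are 0x7D, 0x7E and 0x7F
// (exponent all ones, mantissa non-zero); 0x7C is infinity and must not
// trigger. Adding 0x03 pushes exactly 0x7D..0x7F into bit 7, while
// 0x7C + 0x03 = 0x7F stays below it.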
template <>
struct HasNanFP8x8<c10::Float8_e5m2> {
  static bool check(uint64_t fp8x8) {
    auto t = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
    auto incremented = t + 0x0303030303030303ULL;
    auto overflow = incremented & 0x8080808080808080ULL;
    return overflow != 0;
  }
};

template <typename T>
struct CheckBytePack<T, /*EltPerPack*/ 16> {
  static void check(BytePack* tmp) {
    if (HasNanFP8x8<T>::check(tmp->val[0]) ||
        HasNanFP8x8<T>::check(tmp->val[1]))
      assert(0);
  }
};

#define UNROLL 8

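// Check UNROLL BytePacks per work-item per iteration. `ptr` already points
// at this work-item's first BytePack; striding by the total number of
// work-items (`nWorkers`) keeps neighboring work-items on neighboring
// BytePacks at every unroll step, so the loads stay coalesced.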
template <typename T>
void checkChunk(BytePack* ptr, int nWorkers) {
  BytePack tmp[UNROLL];
  // First load the BytePacks into registers
#pragma unroll 8
  for (int j = 0; j < UNROLL; j++) {
    tmp[j] = ptr[nWorkers * j];
  }
  // Then check each BytePack in the tmp buffer
#pragma unroll 8
  for (int j = 0; j < UNROLL; j++) {
    CheckBytePack<T, sizeof(BytePack) / sizeof(T)>::check(tmp + j);
  }
  // Note: we separate the check from the load for efficient loading
}

// Align address `ptr` up to the next multiple of sizeof(T)
#define ALIGN_UP(ptr, T) \
  (((uintptr_t)ptr + sizeof(T) - 1) / sizeof(T) * sizeof(T))
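// e.g. with sizeof(BytePack) == 16, ALIGN_UP(0x1003, BytePack) == 0x1010.

// The kernel below checks the buffer in three phases: a scalar head up to
// the first BytePack boundary, a vectorized body (UNROLL BytePacks at a
// time, then one at a time), and a scalar tail for the last partial
// BytePack.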
template <typename T>
struct checkForNaN {
  checkForNaN(T* data, size_t size) : data(data), size(size) {}

  void operator()(sycl::nd_item<1> item) const {
    constexpr int EltPerPack = sizeof(BytePack) / sizeof(T);

    size_t offset = item.get_global_id(0);

    // Align the input address up to a BytePack boundary in case it is not
    // aligned
    T* ptrAlign = (T*)ALIGN_UP(data, BytePack);
    size_t preProcElts =
        std::min<size_t>(static_cast<size_t>(ptrAlign - data), size);

    size_t size_left = size;

    // Phase 1: scalar check of the unaligned head
    if (offset < preProcElts) {
      if (at::_isnan(data[offset]))
        assert(0);
    }
    size_left -= preProcElts;

    BytePack* ptr = (BytePack*)ptrAlign;
    size_t sizeInBP = size_left * sizeof(T) / sizeof(BytePack);
    size_t loopSize = item.get_global_range(0) * UNROLL;

    // Phase 2a: unrolled, vectorized body
    for (; offset + loopSize <= sizeInBP; offset += loopSize) {
      checkChunk<T>(ptr + offset, item.get_global_range(0));
    }

    // Phase 2b: leftover BytePacks, one per work-item per iteration
    for (; offset < sizeInBP; offset += item.get_global_range(0)) {
      BytePack tmp = ptr[offset];
      CheckBytePack<T, EltPerPack>::check(&tmp);
    }

    // Phase 3: scalar tail. Every work-group re-checks the same few tail
    // elements (local id, not global id); redundant across work-groups, but
    // correct.
    if (item.get_local_id(0) < size_left % EltPerPack) {
      T* tailPtr = (T*)(ptr + sizeInBP);
      if (at::_isnan(tailPtr[item.get_local_id(0)]))
        assert(0);
    }
  }

 private:
  T* data;
  size_t size;
};

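// One work-item per element is launched below; since the kernel strides
// over BytePacks (EltPerPack elements each), that global range more than
// covers the data.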
template <typename T>
void checkfornan_impl_xpu(
    const at::Tensor& tensor,
    at::xpu::XPUStream& stream) {
  // Skip the check for non-floating-point types
  if (!torch::is_floating_point(tensor)) {
    return;
  }

  int64_t maxNumThreadsPerBlock = syclMaxWorkGroupSize<checkForNaN<T>>();

  const size_t numThreadsPerBlock =
      std::min<size_t>(maxNumThreadsPerBlock, tensor.numel());

  if (numThreadsPerBlock == 0) {
    return;
  }

  int64_t numBlocks =
      (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock;
  auto global_range{numBlocks * numThreadsPerBlock};
  auto local_range{numThreadsPerBlock};

  using Kernel = checkForNaN<T>;
  auto kfn = Kernel(tensor.data_ptr<T>(), tensor.numel());

  sycl_kernel_submit(global_range, local_range, stream.queue(), kfn);
}

// Check if a Tensor contains NaN in any of its elements
void checkForNan(const at::Tensor& tensor, at::xpu::XPUStream& stream) {
  AT_DISPATCH_FLOATING_TYPES_AND4(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      at::ScalarType::Float8_e4m3fn,
      at::ScalarType::Float8_e5m2,
      tensor.scalar_type(),
      "checkForNaN_XPU",
      [&]() { checkfornan_impl_xpu<scalar_t>(tensor, stream); });
}

} // namespace c10d
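
// A minimal host-side usage sketch (hypothetical call site; the tensor
// construction and stream retrieval are assumptions, not part of this file):
//
//   at::Tensor t =
//       at::randn({1024}, at::device(at::kXPU).dtype(at::kFloat));
//   t[42] = std::numeric_limits<float>::quiet_NaN();
//   auto stream = at::xpu::getCurrentXPUStream();
//   c10d::checkForNan(t, stream); // device-side assert fires on the NaN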