|
23 | 23 | #include "fbgemm_gpu/utils/assert_macros.h" |
24 | 24 | #include "fbgemm_gpu/utils/kernel_launcher.cuh" |
25 | 25 |
|
| 26 | +{%- if is_rocm %} |
| 27 | +#include "fbgemm_gpu/rocm/cdna_guard.h" |
| 28 | +{%- endif %} |
| 29 | + |
26 | 30 | using Tensor = at::Tensor; |
27 | 31 | using namespace fbgemm_gpu; |
28 | 32 |
|
@@ -209,8 +213,127 @@ __global__ __launch_bounds__(kForwardMaxThreads) void |
209 | 213 | 2, offset_idx + D_emb <= weights_numel, offset_idx |
210 | 214 | ) |
211 | 215 | {%- endif %} |
| 216 | + int32_t j = 0; |
| 217 | + {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %} |
| 218 | + // Currently for split_embedding_codegen_grad_indice_weights_kernel only |
| 219 | + if (placement != PlacementType::MANAGED_CACHING) { |
| 220 | + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { |
| 221 | + const auto offset_idx_j0 = shfl_sync(offset_idx, j); |
| 222 | + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); |
| 223 | + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); |
| 224 | + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); |
| 225 | + |
| 226 | + at::acc_type<cache_t, true> grad_indice_weight0 = 0.0; |
| 227 | + at::acc_type<cache_t, true> grad_indice_weight1 = 0.0; |
| 228 | + at::acc_type<cache_t, true> grad_indice_weight2 = 0.0; |
| 229 | + at::acc_type<cache_t, true> grad_indice_weight3 = 0.0; |
| 230 | + |
| 231 | + const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D); |
| 232 | + const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D); |
| 233 | + const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D); |
| 234 | + const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D); |
| 235 | + |
| 236 | + #pragma unroll kFixedMaxVecsPerThread |
| 237 | + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { |
| 238 | + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; |
| 239 | + |
| 240 | + Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3; |
| 241 | + weight0 = weight_row0.load(d); |
| 242 | + weight1 = weight_row1.load(d); |
| 243 | + weight2 = weight_row2.load(d); |
| 244 | + weight3 = weight_row3.load(d); |
| 245 | + |
| 246 | + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + |
| 247 | + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; |
| 248 | + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + |
| 249 | + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; |
| 250 | + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + |
| 251 | + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; |
| 252 | + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + |
| 253 | + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; |
| 254 | + } |
| 255 | + |
| 256 | + grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0); |
| 257 | + grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1); |
| 258 | + grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2); |
| 259 | + grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3); |
| 260 | + |
| 261 | + if (threadIdx.x == 0) { |
| 262 | + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; |
| 263 | + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; |
| 264 | + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; |
| 265 | + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; |
| 266 | + } |
| 267 | + } |
| 268 | + } else { |
| 269 | + for (; j < kWarpSize && l_start + j + 3 < L; j += 4) { |
| 270 | + const auto offset_idx_j0 = shfl_sync(offset_idx, j); |
| 271 | + const auto offset_idx_j1 = shfl_sync(offset_idx, j+1); |
| 272 | + const auto offset_idx_j2 = shfl_sync(offset_idx, j+2); |
| 273 | + const auto offset_idx_j3 = shfl_sync(offset_idx, j+3); |
| 274 | + |
| 275 | + const auto cache_idx_j0 = shfl_sync(cache_idx, j); |
| 276 | + const auto cache_idx_j1 = shfl_sync(cache_idx, j+1); |
| 277 | + const auto cache_idx_j2 = shfl_sync(cache_idx, j+2); |
| 278 | + const auto cache_idx_j3 = shfl_sync(cache_idx, j+3); |
| 279 | + |
| 280 | + at::acc_type<cache_t, true> grad_indice_weight0 = 0.0; |
| 281 | + at::acc_type<cache_t, true> grad_indice_weight1 = 0.0; |
| 282 | + at::acc_type<cache_t, true> grad_indice_weight2 = 0.0; |
| 283 | + at::acc_type<cache_t, true> grad_indice_weight3 = 0.0; |
| 284 | + |
| 285 | + const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D); |
| 286 | + const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D); |
| 287 | + const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D); |
| 288 | + const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D); |
| 289 | + |
| 290 | + #pragma unroll kFixedMaxVecsPerThread |
| 291 | + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) { |
| 292 | + const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth; |
| 293 | + |
| 294 | + Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3; |
| 295 | + weight0 = (cache_idx_j0 != kCacheLocationMissing) ? |
| 296 | + Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j0][d]) : |
| 297 | + weight_row0.load(d); |
| 298 | + |
| 299 | + weight1 = (cache_idx_j1 != kCacheLocationMissing) ? |
| 300 | + Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j1][d]) : |
| 301 | + weight_row1.load(d); |
| 302 | + |
| 303 | + weight2 = (cache_idx_j2 != kCacheLocationMissing) ? |
| 304 | + Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j2][d]) : |
| 305 | + weight_row2.load(d); |
| 306 | + |
| 307 | + weight3 = (cache_idx_j3 != kCacheLocationMissing) ? |
| 308 | + Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j3][d]) : |
| 309 | + weight_row3.load(d); |
| 310 | + |
| 311 | + |
| 312 | + grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y + |
| 313 | + weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w; |
| 314 | + grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y + |
| 315 | + weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w; |
| 316 | + grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y + |
| 317 | + weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w; |
| 318 | + grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y + |
| 319 | + weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w; |
| 320 | + } |
| 321 | + |
| 322 | + grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0); |
| 323 | + grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1); |
| 324 | + grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2); |
| 325 | + grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3); |
212 | 326 |
|
213 | | - for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) { |
| 327 | + if (threadIdx.x == 0) { |
| 328 | + grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0; |
| 329 | + grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1; |
| 330 | + grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2; |
| 331 | + grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3; |
| 332 | + } |
| 333 | + } |
| 334 | + } |
| 335 | + {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#} |
| 336 | + for (; j < kWarpSize && l_start + j < L; ++j) { |
214 | 337 | const auto offset_idx_j = shfl_sync(offset_idx, j); |
215 | 338 | {%- if not dense %} |
216 | 339 | const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j); |
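The unrolled loop added above processes four pooled rows per warp iteration: each lane's `offset_idx` is broadcast to the whole warp with `shfl_sync`, every lane accumulates a strided partial dot product between its slice of the embedding row and `grad_out`, and `warpReduceAllSum` folds the partials so lane 0 can write the per-index gradient. Below is a minimal, self-contained CUDA sketch of that pattern only; it uses raw intrinsics and plain floats instead of the template's `Vec4T`, `WeightRowAccessor`, and FBGEMM warp helpers, and the kernel and parameter names (`row_dot_kernel`, `row_offsets`) are illustrative assumptions, not part of this PR.

```cuda
// Sketch of the warp-cooperative per-row dot product used by the kernel above.
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

__device__ float warp_all_sum(float v) {
  // Butterfly reduction: after the loop every lane holds the warp-wide sum.
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    v += __shfl_xor_sync(0xffffffffu, v, offset);
  }
  return v;
}

__global__ void row_dot_kernel(
    const float* weights,     // flat weight storage
    const float* grad_out,    // [D] upstream gradient for this sample
    const long* row_offsets,  // [L] start offsets of the pooled rows
    float* out,               // [L] per-row dot products
    int D,
    int L) {
  const int lane = threadIdx.x % kWarpSize;
  // Each lane owns one candidate row offset, as in the kernel above.
  const long my_offset = (lane < L) ? row_offsets[lane] : 0;

  for (int j = 0; j < kWarpSize && j < L; ++j) {
    // Broadcast lane j's offset to the whole warp (the shfl_sync step).
    const long offset_j = __shfl_sync(0xffffffffu, my_offset, j);

    // Strided partial dot product: lane t covers columns t, t+32, t+64, ...
    float partial = 0.0f;
    for (int d = lane; d < D; d += kWarpSize) {
      partial += weights[offset_j + d] * grad_out[d];
    }

    // Warp all-reduce, then a single lane writes the result.
    const float total = warp_all_sum(partial);
    if (lane == 0) {
      out[j] = total;
    }
  }
}
```

Launched with a single 32-thread warp, e.g. `row_dot_kernel<<<1, 32>>>(weights, grad_out, row_offsets, out, D, L)`. The real template abstracts the warp width and vector width behind `kWarpSize` and `kVecWidth`, and the unroll-by-4 variant simply carries four broadcast offsets and four accumulators through the same steps.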
@@ -359,6 +482,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda( |
359 | 482 | auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output); |
360 | 483 |
|
361 | 484 | CUDA_DEVICE_GUARD(dev_weights); |
| 485 | + #ifdef USE_ROCM |
| 486 | + if (!rocm::is_supported_cdna()) { |
| 487 | + TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal."); |
| 488 | + } |
| 489 | + else { |
| 490 | +     // Supported CDNA architecture detected (including MI350) |
| 491 | + TORCH_WARN_ONCE("Running on CDNA architecture"); |
| 492 | + } |
| 493 | + #endif |
362 | 494 |
|
363 | 495 | const auto T = D_offsets.size(0) - 1; |
364 | 496 | TORCH_CHECK_GT(T, 0); |
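The new guard calls `rocm::is_supported_cdna()` from the freshly included `fbgemm_gpu/rocm/cdna_guard.h`; that header's implementation is not part of this diff. As a hedged sketch only, a host-side check of this kind could be built on the HIP runtime's device properties. The architecture prefix list (gfx90a, gfx942, gfx950) and all names below are assumptions, not the shipped guard:

```cpp
#include <hip/hip_runtime.h>
#include <initializer_list>
#include <string>

namespace rocm_sketch {

// Returns true when the current HIP device reports a CDNA-class gfx arch.
// The prefix list below is an assumption; adjust it to match the real guard.
inline bool is_supported_cdna() {
  int device = 0;
  if (hipGetDevice(&device) != hipSuccess) {
    return false;
  }
  hipDeviceProp_t prop;
  if (hipGetDeviceProperties(&prop, device) != hipSuccess) {
    return false;
  }
  // gcnArchName looks like "gfx942:sramecc+:xnack-", so match on the prefix.
  const std::string arch(prop.gcnArchName);
  for (const char* cdna : {"gfx90a", "gfx942", "gfx950"}) {
    if (arch.rfind(cdna, 0) == 0) {
      return true;
    }
  }
  return false;
}

} // namespace rocm_sketch
```

As written in the diff, the check only warns via `TORCH_WARN_ONCE`; it does not change which code path the kernel takes on non-CDNA hardware.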
|