
Commit f518d9c

Merge pull request #132 from ROCm/shbiswas/sparse_group_opt
Shbiswas/sparse group opt
2 parents: 648e57a + 735b803

20 files changed: +897 additions, −299 deletions

.github/scripts/utils_build.bash

Lines changed: 1 addition & 0 deletions
@@ -370,6 +370,7 @@ install_build_tools () {
     patchelf \
     rhash \
     scikit-build \
+    tbb-devel \
     tbb \
     wheel \
     xz \

cmake/modules/CppLibrary.cmake

Lines changed: 12 additions & 0 deletions
@@ -168,6 +168,18 @@ function(cpp_library)
     target_link_libraries(${lib_name} PUBLIC OpenMP::OpenMP_CXX)
   endif()

+  if(NOT TARGET TBB::tbb)
+    find_package(TBB QUIET)
+  endif()
+  if(TBB_FOUND)
+    target_link_libraries(${lib_name} PUBLIC TBB::tbb)
+  else()
+    find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
+    if(TBB_LIB)
+      target_link_libraries(${lib_name} PUBLIC ${TBB_LIB})
+    endif()
+  endif()
+
   # Add sanitizer options if needed
   if(args_SANITIZER_OPTIONS)
     target_link_options(${lib_name} PUBLIC

cmake/modules/GpuCppLibrary.cmake

Lines changed: 12 additions & 0 deletions
@@ -302,6 +302,18 @@ function(gpu_cpp_library)
     list(APPEND library_dependencies ${NVML_LIB_PATH})
   endif()

+  if(NOT TARGET TBB::tbb)
+    find_package(TBB QUIET)
+  endif()
+  if(TBB_FOUND)
+    list(APPEND library_dependencies TBB::tbb)
+  else()
+    find_library(TBB_LIB NAMES tbb tbb12 HINTS $ENV{CONDA_PREFIX}/lib /usr/lib/x86_64-linux-gnu /usr/local/lib /lib/x86_64-linux-gnu)
+    if(TBB_LIB)
+      list(APPEND library_dependencies ${TBB_LIB})
+    endif()
+  endif()
+
   # Link against the external libraries as needed
   target_link_libraries(${lib_name} PRIVATE ${library_dependencies})
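Both CMake modules now prefer the imported TBB::tbb target (which carries include directories and usage requirements) and fall back to a raw find_library lookup of tbb/tbb12 only when no TBB package config is found. The snippet below is not part of the commit; it is a minimal C++ smoke test, assuming oneTBB headers are installed alongside the library, that can be linked against whichever TBB was located to confirm the dependency actually resolves:

// check_tbb.cpp - hypothetical smoke test, not part of this commit.
// Build sketch (assumption): g++ -std=c++17 check_tbb.cpp -ltbb && ./a.out
#include <tbb/parallel_for.h>

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> v(1000, 1);
  // Trivial parallel loop: if this links and runs, the TBB library found by
  // find_package()/find_library() above is usable.
  tbb::parallel_for(std::size_t{0}, v.size(),
                    [&](std::size_t i) { v[i] += 1; });
  std::printf("v[0] = %d\n", v[0]);
  return 0;
}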

fbgemm_gpu/cmake/tbe_sources.py

Lines changed: 0 additions & 2 deletions
@@ -176,7 +176,6 @@
         "_nobag" if nobag else "",
     )
     for nobag in [
-        True,
         False,
     ]
     for weighted in (
@@ -495,7 +494,6 @@
         "_nobag" if nobag else "",
     )
     for nobag in [
-        True,
         False,
     ]
     for weighted in (

fbgemm_gpu/codegen/genscript/generate_backward_split.py

Lines changed: 7 additions & 3 deletions
@@ -52,7 +52,11 @@ def render_backward_templates(
            return

        weighted_options = [True, False]
-        nobag_options = [True, False] if (not is_gwd) else [False]
+        nobag_options = (
+            [True, False]
+            if (not (is_gwd or kwargs.get("is_hip_optimized_backward")))
+            else [False]
+        )
        vbe_options = [True, False] if (kwargs.get("has_vbe_support")) else [False]
        ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
        template = CodeTemplate.load(template_filepath)
@@ -327,8 +331,7 @@ def generate_backward_indices() -> None:

    @staticmethod
    def generate_rocm_backward_split(**kwargs: Any) -> None:
-        # Generate backward device kernels based on weighted (True/False), VBE
-        # (True/False), no bag (True/False)
+        # Generate backward device kernels based on weighted (True/False)
        template_filepath = (
            "training/backward/rocm/embedding_backward_split_device_kernel_template.hip"
        )
@@ -343,6 +346,7 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:
                "has_ssd_support": False,
                "dense": False,
                "gen_once": False,
+                "is_hip_optimized_backward": True,
            },
        )

fbgemm_gpu/codegen/training/backward/embedding_backward_dense_host_cpu.cpp

Lines changed: 1 addition & 1 deletion
@@ -172,7 +172,7 @@ Tensor split_embedding_codegen_lookup_dense_function(
     c10::SymInt /* max_B = -1 */,
     c10::SymInt /* max_B_feature_rank = -1 */,
     c10::SymInt /* vbe_output_size = -1 */,
-    bool /* mixed_D = true */) {
+    bool /* mixed_D = false */) {
   return SplitLookupFunction_Dense_Op::apply(
       host_weights,
       weights_offsets,

fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp

Lines changed: 2 additions & 2 deletions
@@ -960,7 +960,7 @@ class {{ autograd_func }} :

    #ifdef USE_ROCM
    constexpr int32_t BT_block_size = 64;
-    constexpr int32_t max_segment_length_per_warp = 64;
+    constexpr int32_t max_segment_length_per_warp = 16384;
    #else
    constexpr int32_t BT_block_size = 32;
    constexpr int32_t max_segment_length_per_warp = 32;
@@ -1116,7 +1116,7 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
    {%- else %}
    const c10::SymInt vbe_output_size = -1,
    {%- endif %}
-    const bool mixed_D = true
+    const bool mixed_D = false
 ) {
   // TODO: refactor into macro
   {%- if has_gpu_support %}

fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu

File mode changed: 100644 → 100755
Lines changed: 133 additions & 1 deletion
@@ -23,6 +23,10 @@
 #include "fbgemm_gpu/utils/assert_macros.h"
 #include "fbgemm_gpu/utils/kernel_launcher.cuh"

+{%- if is_rocm %}
+#include "fbgemm_gpu/rocm/cdna_guard.h"
+{%- endif %}
+
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;

@@ -209,8 +213,127 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
         2, offset_idx + D_emb <= weights_numel, offset_idx
     )
     {%- endif %}
+    int32_t j = 0;
+    {%- if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe %}
+    // Currently for split_embedding_codegen_grad_indice_weights_kernel only
+    if (placement != PlacementType::MANAGED_CACHING) {
+      for (; j < kWarpSize && l_start + j + 3 < L; j += 4) {
+        const auto offset_idx_j0 = shfl_sync(offset_idx, j);
+        const auto offset_idx_j1 = shfl_sync(offset_idx, j+1);
+        const auto offset_idx_j2 = shfl_sync(offset_idx, j+2);
+        const auto offset_idx_j3 = shfl_sync(offset_idx, j+3);
+
+        at::acc_type<cache_t, true> grad_indice_weight0 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight1 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight2 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight3 = 0.0;
+
+        const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D);
+        const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D);
+        const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D);
+        const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D);
+
+        #pragma unroll kFixedMaxVecsPerThread
+        for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) {
+          const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth;
+
+          Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3;
+          weight0 = weight_row0.load(d);
+          weight1 = weight_row1.load(d);
+          weight2 = weight_row2.load(d);
+          weight3 = weight_row3.load(d);
+
+          grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y +
+                                 weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y +
+                                 weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y +
+                                 weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y +
+                                 weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w;
+        }
+
+        grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0);
+        grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1);
+        grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2);
+        grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3);
+
+        if (threadIdx.x == 0) {
+          grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0;
+          grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1;
+          grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2;
+          grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3;
+        }
+      }
+    } else {
+      for (; j < kWarpSize && l_start + j + 3 < L; j += 4) {
+        const auto offset_idx_j0 = shfl_sync(offset_idx, j);
+        const auto offset_idx_j1 = shfl_sync(offset_idx, j+1);
+        const auto offset_idx_j2 = shfl_sync(offset_idx, j+2);
+        const auto offset_idx_j3 = shfl_sync(offset_idx, j+3);
+
+        const auto cache_idx_j0 = shfl_sync(cache_idx, j);
+        const auto cache_idx_j1 = shfl_sync(cache_idx, j+1);
+        const auto cache_idx_j2 = shfl_sync(cache_idx, j+2);
+        const auto cache_idx_j3 = shfl_sync(cache_idx, j+3);
+
+        at::acc_type<cache_t, true> grad_indice_weight0 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight1 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight2 = 0.0;
+        at::acc_type<cache_t, true> grad_indice_weight3 = 0.0;
+
+        const auto weight_row0 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j0], D);
+        const auto weight_row1 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j1], D);
+        const auto weight_row2 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j2], D);
+        const auto weight_row3 = WeightRowAccessor<emb_t, at::acc_type<cache_t, true>>(&weights[offset_idx_j3], D);
+
+        #pragma unroll kFixedMaxVecsPerThread
+        for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && (kWarpSize * vec + threadIdx.x) * kVecWidth < D; ++vec) {
+          const int32_t d = (kWarpSize * vec + threadIdx.x) * kVecWidth;
+
+          Vec4T<at::acc_type<cache_t, true>> weight0, weight1, weight2, weight3;
+          weight0 = (cache_idx_j0 != kCacheLocationMissing) ?
+            Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j0][d]) :
+            weight_row0.load(d);
+
+          weight1 = (cache_idx_j1 != kCacheLocationMissing) ?
+            Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j1][d]) :
+            weight_row1.load(d);
+
+          weight2 = (cache_idx_j2 != kCacheLocationMissing) ?
+            Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j2][d]) :
+            weight_row2.load(d);
+
+          weight3 = (cache_idx_j3 != kCacheLocationMissing) ?
+            Vec4T<at::acc_type<cache_t, true>>(&lxu_cache_weights[cache_idx_j3][d]) :
+            weight_row3.load(d);
+
+          grad_indice_weight0 += weight0.acc.x * grad_out[vec].acc.x + weight0.acc.y * grad_out[vec].acc.y +
+                                 weight0.acc.z * grad_out[vec].acc.z + weight0.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight1 += weight1.acc.x * grad_out[vec].acc.x + weight1.acc.y * grad_out[vec].acc.y +
+                                 weight1.acc.z * grad_out[vec].acc.z + weight1.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight2 += weight2.acc.x * grad_out[vec].acc.x + weight2.acc.y * grad_out[vec].acc.y +
+                                 weight2.acc.z * grad_out[vec].acc.z + weight2.acc.w * grad_out[vec].acc.w;
+          grad_indice_weight3 += weight3.acc.x * grad_out[vec].acc.x + weight3.acc.y * grad_out[vec].acc.y +
+                                 weight3.acc.z * grad_out[vec].acc.z + weight3.acc.w * grad_out[vec].acc.w;
+        }
+
+        grad_indice_weight0 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight0);
+        grad_indice_weight1 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight1);
+        grad_indice_weight2 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight2);
+        grad_indice_weight3 = warpReduceAllSum<at::acc_type<cache_t, true>>(grad_indice_weight3);

-    for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
+        if (threadIdx.x == 0) {
+          grad_indice_weights[indices_start + l_start + j] = grad_indice_weight0;
+          grad_indice_weights[indices_start + l_start + j+1] = grad_indice_weight1;
+          grad_indice_weights[indices_start + l_start + j+2] = grad_indice_weight2;
+          grad_indice_weights[indices_start + l_start + j+3] = grad_indice_weight3;
+        }
+      }
+    }
+    {%- endif %}{#-/* if is_rocm and not ssd and not dense and not use_vec_blocking and not vbe */#}
+    for (; j < kWarpSize && l_start + j < L; ++j) {
       const auto offset_idx_j = shfl_sync(offset_idx, j);
       {%- if not dense %}
       const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
@@ -359,6 +482,15 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
   auto aligned_grad_output = aligned_grad_output_tensor_for_cuda_backwards(grad_output);

   CUDA_DEVICE_GUARD(dev_weights);
+  #ifdef USE_ROCM
+  if (!rocm::is_supported_cdna()) {
+    TORCH_WARN_ONCE("Running on non-CDNA architecture. Performance may be suboptimal.");
+  }
+  else {
+    // Ensure we're running on a supported CDNA architecture (including MI350)
+    TORCH_WARN_ONCE("Running on CDNA architecture");
+  }
+  #endif

   const auto T = D_offsets.size(0) - 1;
   TORCH_CHECK_GT(T, 0);
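The bulk of the additions above manually unroll the per-sample loop by four: lane j's row offset (and, in the MANAGED_CACHING branch, its cache slot) is broadcast to the whole warp with shfl_sync, every lane accumulates a partial dot product between its slice of the embedding row and grad_out, the partials are combined with warpReduceAllSum, and lane 0 writes the per-sample gradients. The kernel below is not FBGEMM code; it is a minimal single-warp CUDA sketch of the same broadcast-then-reduce pattern using plain float rows (the real kernel uses Vec4T vector loads, WeightRowAccessor, and a 64-lane ROCm wavefront):

// Standalone sketch, not FBGEMM code: one 32-thread warp computes
// dot(weights[idx[j]], grad) for each lane's sample j.
#include <cuda_runtime.h>

#include <cstdio>
#include <vector>

constexpr int kWarp = 32;
constexpr unsigned kFullMask = 0xffffffffu;

__device__ float warp_reduce_sum(float v) {
  // Butterfly-style shuffle reduction; the sum ends up valid in lane 0.
  for (int offset = kWarp / 2; offset > 0; offset >>= 1) {
    v += __shfl_down_sync(kFullMask, v, offset);
  }
  return v;
}

__global__ void grad_indice_weight_sketch(
    const float* __restrict__ weights,  // [num_rows, D]
    const float* __restrict__ grad,     // [D], one gradient row for simplicity
    const int* __restrict__ idx,        // [kWarp], row owned by each lane
    float* __restrict__ out,            // [kWarp]
    int D) {
  const int lane = threadIdx.x;  // launched with a single 32-thread warp
  const int my_row = idx[lane];
  for (int j = 0; j < kWarp; ++j) {
    // Broadcast lane j's row index to every lane (shfl_sync in the template).
    const int row = __shfl_sync(kFullMask, my_row, j);
    // Each lane accumulates a strided slice of dot(weights[row], grad).
    float partial = 0.f;
    for (int d = lane; d < D; d += kWarp) {
      partial += weights[row * D + d] * grad[d];
    }
    // warpReduceAllSum in the template; lane 0 stores the per-sample gradient.
    const float total = warp_reduce_sum(partial);
    if (lane == 0) {
      out[j] = total;
    }
  }
}

int main() {
  const int D = 64;
  std::vector<float> h_w(kWarp * D, 1.0f), h_g(D, 2.0f), h_out(kWarp, 0.f);
  std::vector<int> h_idx(kWarp);
  for (int i = 0; i < kWarp; ++i) h_idx[i] = i;

  float *d_w, *d_g, *d_out;
  int* d_idx;
  cudaMalloc((void**)&d_w, h_w.size() * sizeof(float));
  cudaMalloc((void**)&d_g, h_g.size() * sizeof(float));
  cudaMalloc((void**)&d_out, h_out.size() * sizeof(float));
  cudaMalloc((void**)&d_idx, h_idx.size() * sizeof(int));
  cudaMemcpy(d_w, h_w.data(), h_w.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_g, h_g.data(), h_g.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx, h_idx.data(), h_idx.size() * sizeof(int), cudaMemcpyHostToDevice);

  grad_indice_weight_sketch<<<1, kWarp>>>(d_w, d_g, d_idx, d_out, D);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);

  std::printf("out[0] = %.1f (expected %.1f)\n", h_out[0], 2.0f * D);
  return 0;
}

The committed version additionally processes samples j, j+1, j+2, j+3 per iteration, so four rows' global loads and FMAs are in flight at once, which is where the claimed speedup on this path comes from.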

fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu

Lines changed: 13 additions & 32 deletions
@@ -32,6 +32,14 @@

 {%- set locs_or_addrs_tensor = "ssd_row_addrs" if ssd else "lxu_cache_locations" %}
 {%- set locs_or_addrs_type = "int64_t" if ssd else "int32_t" %}
+{%- set is_optimized_hip_kernel_supported_mode = is_rocm and
+    optimizer == "rowwise_adagrad" and
+    not dense and
+    not nobag and
+    not is_index_select and
+    not is_gwd_kernel and
+    not vbe and
+    not ssd %}

 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
 #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
@@ -538,7 +546,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row

 {%- endif %}

-{%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and not dense and not is_gwd_kernel and not vbe and not ssd %}
+{%- if is_optimized_hip_kernel_supported_mode %}
 #include <hip/hip_runtime.h>
 #include <hip/hip_fp16.h>
 #include "fbgemm_gpu/rocm/split_embeddings_common.h"
@@ -612,12 +620,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     {{ args.split_kernel_args | replace_pta_namespace() | join(",\n ") }}
     {%- endif %}
 ) {
-    {%- if not nobag %}
     int32_t T = D_offsets.size(0) - 1;
-    {%- else %}
-    int32_t T = weights_offsets.size(0);
-    {%- endif %}
-
     auto p_output_grad = grad_output.data();
     auto p_emb_table = dev_weights.data();
     auto p_hash_size_cumsum = hash_size_cumsum.data();
@@ -632,8 +635,6 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     constexpr int32_t segment_prefetch = 2;
     constexpr int32_t segment_unroll = 8;
     constexpr int32_t segment_split = 0;
-    auto batch = grad_output.size(0);
-    auto num_rows = dev_weights.size(0) / T / max_D;
     {%- if weighted %}
     constexpr bool is_weighted = true;
     {%- else %}
@@ -646,30 +647,15 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     // weight_decay(_mode) is supplied as args.split_function_args_no_defaults
     opt_karg.weight_decay_mode = weight_decay_mode_v;
     opt_karg.weight_decay = weight_decay;
-    auto batch_mdiv = [](uint32_t d) -> rocm::magic_div_u32_t {
-        assert(d >= 1 && d <= INT32_MAX);
-        uint8_t shift;
-        for(shift = 0; shift < 32; shift++)
-            if((1U << shift) >= d)
-                break;
-
-        uint64_t one = 1;
-        uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
-        assert(magic <= 0xffffffffUL);
-
-        rocm::magic_div_u32_t result;
-        result.magic = magic;
-        result.shift = shift;
-        return result;
-    }(batch);
+
     rocm::split_tbe_backward_hip_kernel_{{kdesc}}<
-        rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, embedding_dim, weight_decay_mode_v>,
+        rocm::{{optimizer}}_optimizer_t<cache_t, emb_t, index_t, embedding_dim, weight_decay_mode_v>,
         rocm::{{optimizer}}_kernel_arg_t,
         emb_t,
         cache_t,
         grad_t,
         index_t,
-        BLOCK_SIZE,
+        BLOCK_SIZE_ROCM,
         embedding_dim,
         segment_prefetch,
         segment_unroll,
@@ -680,16 +666,11 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
         p_sorted_linear_indices_run,
         p_sorted_linear_indices_cumulative_run_lengths,
         p_sorted_linear_indices_num_runs,
-        {%- if not nobag %}
         info_B_num_bits,
         info_B_mask,
-        {%- endif %}
         p_sorted_infos,
-        batch_mdiv,
         max_segment_length_per_warp,
         emb_dim,
-        batch,
-        num_rows,
         T,
         opt_karg
         {%- if weighted %}
@@ -784,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
 {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %}
 {%- for cache_type in ['float', 'at::Half'] %}
 {%- for index_type in ['int32_t', 'int64_t'] %}
-{%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
+{%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %}
 {%- for kWeighDecayMode in [0, 1, 2] %}
 {{ hip_template_instantiation(
     emb_type,
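The deleted batch_mdiv lambda precomputed magic-number parameters for fast unsigned division by the batch size; the updated rocm::split_tbe_backward_hip_kernel signature no longer takes batch_mdiv, batch, or num_rows. For reference, here is a standalone host-side C++ sketch of that magic-division scheme: the parameter computation is copied from the removed lambda, while the runtime multiply-high step is an assumption standing in for the rocm::magic_div_u32 device helper in split_embeddings_common.h, which this diff does not show.

#include <cassert>
#include <cstdint>
#include <cstdio>

struct magic_div_u32_t {
  uint32_t magic;
  uint32_t shift;
};

// Parameter computation, copied from the lambda this commit removes.
magic_div_u32_t magic_div_u32_gen(uint32_t d) {
  assert(d >= 1 && d <= INT32_MAX);
  uint32_t shift;
  for (shift = 0; shift < 32; shift++)
    if ((1U << shift) >= d)
      break;
  const uint64_t one = 1;
  const uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
  assert(magic <= 0xffffffffULL);
  return {static_cast<uint32_t>(magic), shift};
}

// Assumed runtime companion (multiply-high, add, shift), done in 64-bit here
// to sidestep intermediate overflow; the real device helper is not shown in
// this diff.
uint32_t magic_div_u32_run(magic_div_u32_t m, uint32_t n) {
  const uint64_t hi = (static_cast<uint64_t>(n) * m.magic) >> 32;
  return static_cast<uint32_t>((hi + n) >> m.shift);
}

int main() {
  // Spot-check the scheme against ordinary integer division.
  for (uint32_t d : {1u, 3u, 7u, 64u, 1000u, 123457u}) {
    const auto m = magic_div_u32_gen(d);
    for (uint32_t n : {0u, 1u, 5u, 1023u, 65536u, 1000000u}) {
      assert(magic_div_u32_run(m, n) == n / d);
    }
  }
  std::printf("magic division matches integer division on the samples\n");
  return 0;
}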
