Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions csrc/jit_kernels/heuristics/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,20 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k

// Select M/N block sizes
// TODO: support `% 16 == 8` block size on SM90
const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ?
std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256};
auto block_ms = std::vector{64, 128, 256};
if (gemm_type == GemmType::MGroupedContiguous)
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
block_ms = std::vector{64, 128};

std::vector<int> block_ns;
for (int i = 16; i <= 256; i += 16)
block_ns.push_back(i);
if(get_env<int>("ENABLE_SWAPAB")){
block_ms = std::vector{32}; // 32, 64
block_ns = std::vector{256}; // 256 for H20, and can choose 64, 128, 256
}


// K block size is selected in a fixed manner
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
Expand Down Expand Up @@ -214,7 +223,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
MulticastConfig best_multicast_config = {1, true};
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
gemm_type, m, n, best_block_m, best_block_n, num_sms);
const bool is_legal[2] = {is_legal_on_a, is_legal_on_b};
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
bool order[2] = {false, true};
if (best_block_m > best_block_n)
std::swap(order[0], order[1]);
Expand Down
2 changes: 1 addition & 1 deletion csrc/jit_kernels/heuristics/sm100.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ struct SM100ArchSpec {
const int& num_sms) {
// TODO: support other layouts
return {
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
false,
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
};
}

Expand Down
25 changes: 20 additions & 5 deletions csrc/jit_kernels/heuristics/sm90.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,15 @@ struct SM90ArchSpec {

// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
// Or too many register spills
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
return false;

if(get_env<int>("ENABLE_SWAPAB")){
if (block_n != 64 and block_n != 128 and block_n != 256)
return false;
}else{
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
return false;
}

// Avoid bank conflicts for FP32 output
if (cd_dtype == torch::kFloat and block_n % 16 == 0)
return false;
Expand All @@ -71,13 +77,21 @@ struct SM90ArchSpec {
const int& num_sms) {
return {
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked,
// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
is_multicast_legal(m, block_m, 2, num_sms, false)
and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true))
};
}

// Chooses the SM90 thread configuration from the block size of the active layout.
// `kernel_type` is part of the signature but is not read by this function.
static ThreadConfig get_thread_config(const KernelType& kernel_type,
                                      const int& block_m, const int& block_n) {
    // When the `ENABLE_SWAPAB` environment value is non-zero, the N block size
    // decides the configuration; otherwise the M block size does.
    const int tile = get_env<int>("ENABLE_SWAPAB") ? block_n : block_m;

    // Tiles larger than 64 pass 2 * 128 as the second argument, all others 1 * 128.
    // NOTE(review): the two arguments look like (TMA threads, math threads) —
    // confirm against the declaration of `ThreadConfig::sm90`.
    return ThreadConfig::sm90(128, (tile > 64 ? 2 : 1) * 128);
}

static int get_smem_cd_size(const KernelType& kernel_type,
Expand All @@ -102,7 +116,8 @@ struct SM90ArchSpec {

// Returns the extra shared-memory size, in bytes, reserved for SFB
// (presumably the scaling factors of operand B — confirm against the kernel).
// The size is `ceil_div(k, block_k)` floats times a layout-dependent multiplier,
// passed through `align<int>(..., 8)`.
// `m`, `n` and `block_m` are part of the signature but are not read here.
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
                                   const int& block_m, const int& block_n, const int& block_k) {
    // Multiplier applied to the per-K-block float count:
    //  - `ENABLE_SWAPAB` non-zero: `block_n / 64` (integer division, so a
    //    `block_n` below 64 would yield 0 and reserve no space);
    //  - otherwise: 1 when `block_n` divides `block_k` evenly, else 2.
    const int sfb_multiplier = get_env<int>("ENABLE_SWAPAB")
        ? (block_n / 64)
        : (block_k % block_n == 0 ? 1 : 2);

    return align<int>(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * sfb_multiplier, 8);
}

Expand Down
9 changes: 8 additions & 1 deletion csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,19 @@ class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime>
};

static std::string generate_impl(const Args& args) {

const char* kernel_name =
get_env<int>("ENABLE_SWAPAB") ?
"swapAB_sm90_fp8_gemm_1d2d_impl" :
"sm90_fp8_gemm_1d2d_impl";

return fmt::format(R"(
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
auto ptr = reinterpret_cast<void*>(&{}<
{}, {}, {},
{},
{}, {}, {},
Expand All @@ -47,6 +53,7 @@ static void __instantiate_kernel() {{
>);
}};
)",
kernel_name,
// TODO: add CD dtype
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.num_groups,
Expand Down
11 changes: 11 additions & 0 deletions deep_gemm/include/deep_gemm/common/sm90_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ struct SM90_U32x2_STSM_N {
}
};

// Wrapper around the transposed form of the `stmatrix` PTX instruction
// (`stmatrix.sync.aligned.x2.m8n8.shared.b16.trans`).
//
// `copy` reinterprets each of the two `dtype_t` arguments as one 32-bit register
// and passes both, together with `smem_dst`, to the instruction.
// Per the PTX ISA, `.trans` selects the column-major store variant — see the
// `stmatrix` documentation for the exact per-lane address and fragment layout.
// NOTE(review): `smem_dst` is passed as a 64-bit ("l") operand to a `.shared`
// instruction and no "memory" clobber is declared; confirm callers rely on the
// same conventions as the other STSM helpers in this header.
template <typename dtype_t>
struct SM90_U32x2_STSM_T
{
    // Each source value is read back as exactly one `uint32_t`; any other size
    // would over-read the argument or silently drop data.
    static_assert(sizeof(dtype_t) == sizeof(uint32_t),
                  "SM90_U32x2_STSM_T requires a 32-bit source type");

    __device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst)
    {
        const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
        asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n" ::"l"(smem_dst), "r"(src[0]),
                     "r"(src[1]));
    }
};

// Emits `wgmma.fence.sync.aligned`. The "memory" clobber tells the compiler
// that memory may be touched, so it does not move memory accesses across it.
__device__ __forceinline__ void warpgroup_arrive() {
    asm volatile(
        "wgmma.fence.sync.aligned;\n"
        :
        :
        : "memory");
}
Expand Down
6 changes: 6 additions & 0 deletions deep_gemm/include/deep_gemm/common/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ __device__ __forceinline__ float ld_shared(const float* ptr) {
return ret;
}

// Loads two consecutive floats with one `ld.shared.v2.f32` instruction and
// returns them as a `float2` (`x` is the first element, `y` the second).
// `ptr` is handed to the instruction as a 64-bit ("l") address operand.
__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
    float2 loaded;
    asm volatile(
        "ld.shared.v2.f32 {%0, %1}, [%2];"
        : "=f"(loaded.x), "=f"(loaded.y)
        : "l"(ptr));
    return loaded;
}

// Stores `val` to the address in `ptr` with one `st.shared.f32` instruction.
// Note that `ptr` is declared `const float*` even though the instruction
// writes through it, and that the asm statement declares no "memory" clobber.
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
    asm volatile(
        "st.shared.f32 [%0], %1;"
        :
        : "l"(ptr), "f"(val));
}
Expand Down
Loading