Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
}
}

Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
}
}

Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ void main() {

O = Ldiag*O;

[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }

uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ void main() {
}
}
O *= L;

const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
O = clamp(O, -FLT_MAX, FLT_MAX);

data_d[iq3 * D * N + D * n + d] = O;
}
}
12 changes: 12 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,18 @@ void main() {
barrier();
}

#ifdef COOPMAT
[[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
[[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
}
}
#else
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
}
#endif

const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN;

Expand Down
7 changes: 7 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }

coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);

coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
Expand Down Expand Up @@ -388,6 +390,8 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }

coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);

coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
Expand Down Expand Up @@ -428,6 +432,8 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }

coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
Expand Down Expand Up @@ -485,6 +491,7 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }

// Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
Expand Down
19 changes: 11 additions & 8 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}

base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
base_dict["ACC_TYPE_MAX"] = f16acc ? "\"float16_t(65504.0)\"" : "\"uintBitsToFloat(0x7F7FFFFF)\"";

if (coopmat) {
base_dict["COOPMAT"] = "1";
Expand Down Expand Up @@ -437,8 +438,10 @@ void process_shaders() {

// flash attention
for (const auto& f16acc : {false, true}) {
std::string acctype = f16acc ? "float16_t" : "float";
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
fa_base_dict["ACC_TYPE_MAX"] = f16acc ? "\"float16_t(65504.0)\"" : "\"uintBitsToFloat(0x7F7FFFFF)\"";

for (const auto& tname : type_names) {
if (tname == "f32") {
Expand All @@ -449,30 +452,30 @@ void process_shaders() {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
}
#endif
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
}
}
}
Expand Down