Skip to content

Commit a57d1bc

Browse files
authored
cuda : support Falcon-H1 state size for SSM_SCAN (#14602)
1 parent cb9178f commit a57d1bc

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3335,8 +3335,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
33353335
case GGML_OP_SSM_SCAN: {
33363336
if (op->src[3]->ne[0] == 1) {
33373337
// Mamba2
3338-
// (kernel only supports d_state == 128 && d_head % 16 == 0)
3339-
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
3338+
// (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
3339+
return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
33403340
} else {
33413341
// Mamba
33423342
// (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
201201
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
202202
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
203203
cudaStream_t stream) {
204-
const int threads = 128;
205204
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
206205
if (src3_nb1 == sizeof(float)) {
207206
// Mamba-2
208207
if (d_state == 128) {
208+
const int threads = 128;
209209
GGML_ASSERT(d_state % threads == 0);
210210
// NOTE: can be any power of two between 4 and 64
211211
const int splitH = 16;
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
215215
src0, src1, src2, src3, src4, src5, src6, dst,
216216
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
217217
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
218+
} else if (d_state == 256) { // Falcon-H1
219+
const int threads = 256;
220+
// NOTE: can be any power of two between 8 and 64
221+
const int splitH = 16;
222+
GGML_ASSERT(head_dim % splitH == 0);
223+
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
224+
ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
225+
src0, src1, src2, src3, src4, src5, src6, dst,
226+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
227+
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
218228
} else {
219-
GGML_ABORT("doesn't support d_state!=128.");
229+
GGML_ABORT("doesn't support d_state!=(128 or 256).");
220230
}
221231
} else {
232+
const int threads = 128;
222233
// Mamba-1
223234
GGML_ASSERT(n_head % threads == 0);
224235
GGML_ASSERT(head_dim == 1);

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5069,6 +5069,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
50695069

50705070
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
50715071
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
5072+
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1
50725073

50735074
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
50745075
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));

0 commit comments

Comments
 (0)