@@ -201,11 +201,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
201
201
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
202
202
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
203
203
cudaStream_t stream) {
204
- const int threads = 128 ;
205
204
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
206
205
if (src3_nb1 == sizeof (float )) {
207
206
// Mamba-2
208
207
if (d_state == 128 ) {
208
+ const int threads = 128 ;
209
209
GGML_ASSERT (d_state % threads == 0 );
210
210
// NOTE: can be any power of two between 4 and 64
211
211
const int splitH = 16 ;
@@ -215,10 +215,21 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
215
215
src0, src1, src2, src3, src4, src5, src6, dst,
216
216
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
217
217
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
218
+ } else if (d_state == 256 ) { // Falcon-H1
219
+ const int threads = 256 ;
220
+ // NOTE: can be any power of two between 8 and 64
221
+ const int splitH = 16 ;
222
+ GGML_ASSERT (head_dim % splitH == 0 );
223
+ const dim3 blocks ((n_head * head_dim + (splitH - 1 )) / splitH, n_seq, 1 );
224
+ ssm_scan_f32_group<16 , 256 ><<<blocks, threads, 0 , stream>>> (
225
+ src0, src1, src2, src3, src4, src5, src6, dst,
226
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
227
+ src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
218
228
} else {
219
- GGML_ABORT (" doesn't support d_state!=128." );
229
+ GGML_ABORT (" doesn't support d_state!=( 128 or 256) ." );
220
230
}
221
231
} else {
232
+ const int threads = 128 ;
222
233
// Mamba-1
223
234
GGML_ASSERT (n_head % threads == 0 );
224
235
GGML_ASSERT (head_dim == 1 );
0 commit comments