@@ -402,120 +402,86 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
-        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
-            const int64_t n_kv         = kv_self->n;
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-
-            float * data     = nullptr;
-            float * data_swa = nullptr;
-
-            if (self_kq_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-                data = (float *) self_kq_mask->data;
-            }
+        const int64_t n_kv         = kv_self->n;
+        const int64_t n_tokens     = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs       = ubatch->n_seqs;

-            if (self_kq_mask_swa) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-                data_swa = (float *) self_kq_mask_swa->data;
-            }
+        float * data     = nullptr;
+        float * data_swa = nullptr;

-            // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the ubatch.
-            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-            for (int h = 0; h < 1; ++h) {
-                for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        if (self_kq_mask) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+            data = (float *) self_kq_mask->data;
+        }

-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+        if (self_kq_mask_swa) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+            data_swa = (float *) self_kq_mask_swa->data;
+        }

-                        for (int i = 0; i < n_kv; ++i) {
-                            float f;
-                            if (!kv_self->cells[i].has_seq_id(seq_id) || kv_self->cells[i].pos > pos) {
-                                f = -INFINITY;
+        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+        //   Causal mask:
+        //      xxx-------
+        //      xxxx------
+        //      xxxxx-----
+        //   Non-causal mask:
+        //      xxxxx-----
+        //      xxxxx-----
+        //      xxxxx-----
+        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+        for (int h = 0; h < 1; ++h) {
+            for (int s = 0; s < n_seqs; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        // mask the token if:
+                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
+                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
+                        ) {
+                            f = -INFINITY;
+                        } else {
+                            if (hparams.use_alibi) {
+                                f = -std::abs(kv_self->cells[i].pos - pos);
                             } else {
-                                if (hparams.use_alibi) {
-                                    f = -std::abs(kv_self->cells[i].pos - pos);
-                                } else {
-                                    f = 0.0f;
-                                }
-                            }
-
-                            if (data) {
-                                data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                                f = 0.0f;
                             }
+                        }

-                            // may need to cut off old tokens for sliding window
-                            if (data_swa) {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                                data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                            }
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
-                        }
-                    }

-                if (data) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
                     }
                 }
+            }

-                if (data_swa) {
-                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                        for (int j = 0; j < n_kv; ++j) {
-                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
-        } else {
-            const int64_t n_tokens     = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs       = ubatch->n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride = n_tokens;

-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-            float * data = (float *) self_kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }
-
-                        for (int i = n_tokens; i < n_stride; ++i) {
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
-                        }
+            // mask padded tokens
+            if (data_swa) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
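The masking rule in the new unified loop can be checked in isolation. Below is a stand-alone C++ sketch (not llama.cpp code; `Cell`, `seq`, and `causal` are illustrative stand-ins for `kv_self->cells[i]`, the sequence id, and `cparams.causal_attn`) that reproduces the toy case from the comment in the diff: a 10-cell KV cache with 2 tokens already populated and 3 batch tokens in a single sequence. Running it prints exactly the causal and non-causal diagrams shown in the comment.

#include <cstdio>
#include <vector>

struct Cell {
    int pos;    // position of the cached token, -1 = empty cell
    int seq_id; // sequence that owns the cell, -1 = empty cell
};

int main() {
    const int n_kv     = 10; // cache size
    const int n_tokens = 3;  // tokens in the current ubatch
    const int seq      = 0;  // the single sequence id

    // cells 0..1 hold the 2 previously decoded tokens (positions 0..1);
    // the 3 batch tokens are assigned cells 2..4 (positions 2..4)
    std::vector<Cell> cells(n_kv, Cell{-1, -1});
    for (int i = 0; i < 2 + n_tokens; ++i) {
        cells[i] = Cell{i, seq};
    }

    for (bool causal : {true, false}) {
        std::printf("%s mask:\n", causal ? "Causal" : "Non-causal");
        for (int j = 0; j < n_tokens; ++j) {
            const int pos = 2 + j; // position of batch token j
            for (int i = 0; i < n_kv; ++i) {
                // same test as in set_input(): mask the cell if it belongs to
                // a different (or no) sequence, or, in the causal case only,
                // if it holds a future position
                const bool masked = cells[i].seq_id != seq ||
                                    (causal && cells[i].pos > pos);
                // the real code writes f = masked ? -INFINITY : 0.0f here
                std::putchar(masked ? '-' : 'x');
            }
            std::putchar('\n');
        }
    }
    return 0;
}

With `hparams.use_alibi`, the unmasked entries would instead hold the distance bias `-|cells[i].pos - pos|` rather than 0.0f, and the rows padded up to GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) are filled with -INFINITY so padded queries attend to nothing.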
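For the sliding-window (SWA) mask, the same value is additionally cut off when a cell falls out of the window before being written into `data_swa`. A minimal sketch of just that test, assuming example values n_swa = 3 and a query at position 6:

#include <cstdio>

int main() {
    const int n_swa = 3; // sliding window size (hparams.n_swa)
    const int pos   = 6; // position of the query token
    for (int cell_pos = 0; cell_pos <= pos; ++cell_pos) {
        const bool cut = pos - cell_pos >= n_swa; // same test as in set_input()
        std::printf("cell pos %d -> %s\n", cell_pos,
                    cut ? "-INFINITY (outside window)" : "visible");
    }
    return 0;
}

Because the test is `pos - cell_pos >= n_swa`, the window covers exactly the n_swa positions `pos - n_swa + 1 .. pos`, i.e. cells at positions 4, 5 and 6 in this example.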