@@ -1348,9 +1348,10 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
     int const batch_idx = blockIdx.z;
 
     // The decoder sequence length.
-    int const decoder_seq_len = params.seq_lens[batch_idx];
+    // Spec decoding is not supported for cross-attention at the moment, so we can use 1 and batch_idx here.
+    int const decoder_seq_len = params.generation_phase ? 1 : params.seq_lens[batch_idx];
     // The decoder sequence offset.
-    int const decoder_seq_offset = params.cu_seq_lens[batch_idx];
+    int const decoder_seq_offset = params.generation_phase ? batch_idx : params.cu_seq_lens[batch_idx];
     // The decoder cache sequence length (includes the current input).
     int const decoder_cache_seq_len = params.cache_seq_lens[batch_idx];
     // The encoder sequence length.
@@ -1411,45 +1412,49 @@ __global__ void updateKVCacheForCrossAttention(QKVPreprocessingParams<T, KVCache
         }
     }
 
-    // Encoder tokens (i.e. KV tokens).
-    if (head_idx == (kv_head_idx * params.qheads_per_kv_head) && token_idx < encoder_seq_len
-        && store_encoder_kv_cache && params.kv_cache_buffer.data != nullptr)
+    if (!params.generation_phase)
     {
-        // The global token idx in all sequences.
-        int global_token_idx = token_idx + encoder_seq_offset;
-
-        // The memory offset.
-        auto const src_k_idx = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + hidden_idx_kv;
-        auto const src_v_idx
-            = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + src_v_offset + hidden_idx_kv;
-
-        // Only load K,V tokens from encoder qkv input.
-        auto k = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_k_idx]);
-        auto v = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_v_idx]);
-
-        // The kv cache pointers.
-        auto k_cache_block_ptr
-            = reinterpret_cast<TCache*>(params.kv_cache_buffer.getKBlockPtr(batch_idx, token_idx));
-        auto v_cache_block_ptr
-            = reinterpret_cast<TCache*>(params.kv_cache_buffer.getVBlockPtr(batch_idx, token_idx));
-        // The vector idx in the cache block.
-        auto block_vec_idx
-            = params.kv_cache_buffer.getKVLocalIdx(token_idx, kv_head_idx, VECS_PER_HEAD, head_dim_vec_idx);
-
-        // Store K and V to the cache.
-        // INT8/FP8 kv cache.
-        if constexpr (sizeof(TCache) == 1)
-        {
-            // The element index inside the block.
-            auto block_elt_idx = block_vec_idx * ELTS_PER_VEC;
-            // Store 8bits kv cache.
-            mmha::store_8bits_vec(k_cache_block_ptr, k, block_elt_idx, scale_orig_quant);
-            mmha::store_8bits_vec(v_cache_block_ptr, v, block_elt_idx, scale_orig_quant);
-        }
-        else
+        // Encoder tokens (i.e. KV tokens).
+        if (head_idx == (kv_head_idx * params.qheads_per_kv_head) && token_idx < encoder_seq_len
+            && store_encoder_kv_cache && params.kv_cache_buffer.data != nullptr)
         {
-            reinterpret_cast<VecT*>(k_cache_block_ptr)[block_vec_idx] = k;
-            reinterpret_cast<VecT*>(v_cache_block_ptr)[block_vec_idx] = v;
+            // The global token idx in all sequences.
+            int global_token_idx = token_idx + encoder_seq_offset;
+
+            // The memory offset.
+            auto const src_k_idx
+                = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + hidden_idx_kv;
+            auto const src_v_idx
+                = static_cast<size_t>(global_token_idx) * params.kv_hidden_size * 2 + src_v_offset + hidden_idx_kv;
+
+            // Only load K,V tokens from encoder qkv input.
+            auto k = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_k_idx]);
+            auto v = *reinterpret_cast<VecT const*>(&params.cross_kv_input[src_v_idx]);
+
+            // The kv cache pointers.
+            auto k_cache_block_ptr
+                = reinterpret_cast<TCache*>(params.kv_cache_buffer.getKBlockPtr(batch_idx, token_idx));
+            auto v_cache_block_ptr
+                = reinterpret_cast<TCache*>(params.kv_cache_buffer.getVBlockPtr(batch_idx, token_idx));
+            // The vector idx in the cache block.
+            auto block_vec_idx
+                = params.kv_cache_buffer.getKVLocalIdx(token_idx, kv_head_idx, VECS_PER_HEAD, head_dim_vec_idx);
+
+            // Store K and V to the cache.
+            // INT8/FP8 kv cache.
+            if constexpr (sizeof(TCache) == 1)
+            {
+                // The element index inside the block.
+                auto block_elt_idx = block_vec_idx * ELTS_PER_VEC;
+                // Store 8bits kv cache.
+                mmha::store_8bits_vec(k_cache_block_ptr, k, block_elt_idx, scale_orig_quant);
+                mmha::store_8bits_vec(v_cache_block_ptr, v, block_elt_idx, scale_orig_quant);
+            }
+            else
+            {
+                reinterpret_cast<VecT*>(k_cache_block_ptr)[block_vec_idx] = k;
+                reinterpret_cast<VecT*>(v_cache_block_ptr)[block_vec_idx] = v;
+            }
         }
     }
 }
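For readers skimming the diff, the sketch below summarizes the behavior change under stated assumptions: during the generation phase the decoder handles exactly one token per sequence (spec decoding is not supported for cross-attention), so the per-batch length becomes 1 and the token offset is simply `batch_idx`, while the context phase keeps the original `seq_lens` / `cu_seq_lens` lookups; the encoder K/V write to the cross-attention cache is gated on the context phase. The struct and helper names (`CrossAttentionParamsSketch`, `decoderTokenRange`, `shouldStoreEncoderKV`) are hypothetical stand-ins, not part of the kernel; only `generation_phase`, `seq_lens`, and `cu_seq_lens` mirror fields used in the diff.

```cpp
// Minimal host-side sketch (not the actual CUDA kernel) of the gating introduced above.
// Assumption: simplified stand-in types; field names mirror those referenced in the diff.
#include <utility>
#include <vector>

struct CrossAttentionParamsSketch // hypothetical stand-in for QKVPreprocessingParams
{
    bool generation_phase = false; // true while decoding one token per step
    std::vector<int> seq_lens;     // decoder input lengths (context phase)
    std::vector<int> cu_seq_lens;  // cumulative decoder token offsets (context phase)
};

// Returns {decoder_seq_len, decoder_seq_offset} for one batch entry.
inline std::pair<int, int> decoderTokenRange(CrossAttentionParamsSketch const& params, int batch_idx)
{
    // Spec decoding is not supported for cross-attention, so one token per sequence suffices.
    int const decoder_seq_len = params.generation_phase ? 1 : params.seq_lens[batch_idx];
    int const decoder_seq_offset = params.generation_phase ? batch_idx : params.cu_seq_lens[batch_idx];
    return {decoder_seq_len, decoder_seq_offset};
}

// The encoder K/V tokens only need to be written to the cross KV cache once, during the
// context phase; generation steps reuse what is already stored.
inline bool shouldStoreEncoderKV(CrossAttentionParamsSketch const& params)
{
    return !params.generation_phase;
}
```

The intent appears to be to avoid rewriting the cross KV cache on every decode step, since the encoder states do not change after the context phase.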