@@ -453,6 +453,7 @@ class PagedAttentionGeneratorGQASingleToken : public PagedAttentionGeneratorSingle
     // GQA
     jit.remove("HEADS_PER_WI");
     jit.make("HEADS_PER_WI", heads_per_wi);
+
     jit.make("ITERATIONS_PER_KV_HEADS_GROUP", ceil_div(kv_group_size, heads_per_wi));
     jit.make("HEADS_LEFTOVERS_NUM", kv_group_size % heads_per_wi);
 
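The GQA constants above split one KV-heads group across work items: ITERATIONS_PER_KV_HEADS_GROUP is how many passes a work item needs when it handles HEADS_PER_WI query heads per pass, and HEADS_LEFTOVERS_NUM is the remainder processed on the last pass. A minimal standalone sketch of that arithmetic, with made-up example values and a local ceil_div stand-in:

#include <cstdio>

// Local stand-in for the ceil_div() helper referenced in the diff.
static constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Illustrative values only: 7 query heads share one KV head,
    // and each work item processes 2 heads per iteration.
    const int kv_group_size = 7;
    const int heads_per_wi = 2;

    const int iterations = ceil_div(kv_group_size, heads_per_wi);  // 4
    const int leftovers = kv_group_size % heads_per_wi;            // 1 head left for the last pass

    std::printf("ITERATIONS_PER_KV_HEADS_GROUP=%d, HEADS_LEFTOVERS_NUM=%d\n", iterations, leftovers);
    return 0;
}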
@@ -1363,11 +1364,15 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         const auto max_context_len = get_max_context_len(params);
         num_of_partitions = ceil_div(max_context_len, partition_size);
     }
-
+    bool can_use_micro_sdpa = stage == PagedAttentionStage::PREFILL;
+#ifdef ENABLE_ONEDNN_FOR_GPU
+    can_use_micro_sdpa &= has_stage(pa_sdpa_micro);
+#endif
     GPU_DEBUG_TRACE_DETAIL << "get_internal_buffer_descs: stage = " << static_cast<size_t>(stage) << std::endl;
     int64_t paged_attention_aligned_seq_len = -1;
     if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !params.is_dynamic()) {
-        paged_attention_aligned_seq_len = get_aligned_seq_len(params, stage);
+        auto block_size = get_query_block_size(stage, can_use_micro_sdpa);
+        paged_attention_aligned_seq_len = get_aligned_seq_len(params, stage, block_size);
     }
     const auto target_seq_len = std::max<int64_t>(paged_attention_aligned_seq_len, 1);
     const auto indexes_buf_size = static_cast<int64_t>(ceil_div(target_seq_len, target_seq_len_block_size)) * element_size;
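The hunk above hoists the can_use_micro_sdpa check ahead of the aligned-sequence-length computation, so the query block size passed to get_aligned_seq_len() can depend on whether the micro-SDPA prefill kernel is available; the matching removal from the old location follows in the next hunk. The body of get_query_block_size(stage, can_use_micro_sdpa) is not part of this diff; a rough sketch of what such a selection could look like, where the constant names and values are assumptions for illustration only:

// Hypothetical sketch: the real get_query_block_size() implementation and the
// block-size constants below are assumptions, not taken from this change.
#include <cstddef>

enum class PagedAttentionStage { GENERATE, PREFILL, MIXED };

constexpr std::size_t micro_sdpa_query_block_size = 256;  // assumed value
constexpr std::size_t default_query_block_size = 16;      // assumed value

std::size_t get_query_block_size(PagedAttentionStage stage, bool can_use_micro_sdpa) {
    // PREFILL with the micro-SDPA kernel available can align sequences to a
    // larger query block; every other case falls back to the generic size.
    if (stage == PagedAttentionStage::PREFILL && can_use_micro_sdpa) {
        return micro_sdpa_query_block_size;
    }
    return default_query_block_size;
}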
@@ -1431,10 +1436,6 @@ class PagedAttentionOptImpl : public SDPAImplBase {
         }
     }
 
-    bool can_use_micro_sdpa = stage == PagedAttentionStage::PREFILL;
-#ifdef ENABLE_ONEDNN_FOR_GPU
-    can_use_micro_sdpa &= has_stage(pa_sdpa_micro);
-#endif
     if (!can_use_micro_sdpa) {
         // GENERATE/MIXED stages and PREFILL stage without micro_sdpa
         internal_buffers.emplace_back(buf_elements_count * element_size, indexes_dt);  // 5: softmax exp_sums
@@ -1610,7 +1611,6 @@ class PagedAttentionOptImpl : public SDPAImplBase {
     const auto block_size = static_cast<int>(query_block_size);
     for (int32_t j = 0; j < seq_length; j += block_size) {
         auto block_start_pos = subsequence_begins_mem_lock[i] + j;
-
         micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = block_start_pos;
         micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = static_cast<int32_t>(i);
     }
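The loop above walks each subsequence in steps of the query block size and writes two int32 entries per block into the micro-SDPA mapping buffer: the block's start position followed by the index of the subsequence it belongs to. A standalone sketch of that packing, using plain vectors in place of the locked device buffers (the function and parameter names here are stand-ins, not the plugin's types):

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustration of the (block_start_pos, subsequence_index) pair layout that the
// diff's loop produces; containers and names are hypothetical.
std::vector<int32_t> build_block_starts_and_gws_mapping(const std::vector<int32_t>& subsequence_begins,
                                                        const std::vector<int32_t>& seq_lengths,
                                                        int32_t block_size) {
    std::vector<int32_t> mapping;
    for (std::size_t i = 0; i < seq_lengths.size(); ++i) {
        for (int32_t j = 0; j < seq_lengths[i]; j += block_size) {
            mapping.push_back(subsequence_begins[i] + j);   // block start position
            mapping.push_back(static_cast<int32_t>(i));     // owning subsequence index
        }
    }
    return mapping;
}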