Commit f666521

[GPU] Fix macro multiple register and micro kernel block size issue (#31651)

### Details:
- Fix multiple registration of the 'HEADS_PER_WI' macro
- Fix the micro kernel block size used when computing aligned_seq_len

### Tickets:
- *CVS-171882*

Co-authored-by: Chen Peter <[email protected]>
1 parent 32e7a5c commit f666521

1 file changed: +7 −7 lines

src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa/paged_attention_opt.cpp

Lines changed: 7 additions & 7 deletions
@@ -453,6 +453,7 @@ class PagedAttentionGeneratorGQASingleToken : public PagedAttentionGeneratorSing
         // GQA
         jit.remove("HEADS_PER_WI");
         jit.make("HEADS_PER_WI", heads_per_wi);
+
         jit.make("ITERATIONS_PER_KV_HEADS_GROUP", ceil_div(kv_group_size, heads_per_wi));
         jit.make("HEADS_LEFTOVERS_NUM", kv_group_size % heads_per_wi);

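The "multiple register" part of the fix concerns HEADS_PER_WI being added to the JIT constants when an earlier definition of the same macro may already exist. As a rough illustration of the remove-before-make pattern visible in the hunk above, here is a minimal, self-contained sketch; the `JitTerms` container and the values are hypothetical and are not the plugin's actual `JitConstants` class.

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a JIT-constants container: make() registers a -D define,
// remove() drops any earlier definition with the same name.
struct JitTerms {
    std::vector<std::pair<std::string, std::string>> terms;

    void make(const std::string& name, const std::string& value) {
        terms.emplace_back(name, value);  // appends unconditionally
    }
    void remove(const std::string& name) {
        terms.erase(std::remove_if(terms.begin(), terms.end(),
                                   [&](const auto& t) { return t.first == name; }),
                    terms.end());
    }
    std::string build_options() const {
        std::string opts;
        for (const auto& t : terms) {
            opts += " -D" + t.first + "=" + t.second;
        }
        return opts;
    }
};

int main() {
    JitTerms jit;
    jit.make("HEADS_PER_WI", "1");  // e.g. registered earlier by a non-GQA code path

    // GQA path: drop the earlier definition before re-registering with the GQA value;
    // without remove() the options would contain both -DHEADS_PER_WI=1 and -DHEADS_PER_WI=4.
    jit.remove("HEADS_PER_WI");
    jit.make("HEADS_PER_WI", "4");

    std::cout << jit.build_options() << std::endl;  // prints " -DHEADS_PER_WI=4"
    return 0;
}
```
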
@@ -1363,11 +1364,15 @@ class PagedAttentionOptImpl : public SDPAImplBase {
             const auto max_context_len = get_max_context_len(params);
             num_of_partitions = ceil_div(max_context_len, partition_size);
         }
-
+        bool can_use_micro_sdpa = stage == PagedAttentionStage::PREFILL;
+#ifdef ENABLE_ONEDNN_FOR_GPU
+        can_use_micro_sdpa &= has_stage(pa_sdpa_micro);
+#endif
         GPU_DEBUG_TRACE_DETAIL << "get_internal_buffer_descs: stage = " << static_cast<size_t>(stage) << std::endl;
         int64_t paged_attention_aligned_seq_len = -1;
         if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !params.is_dynamic()) {
-            paged_attention_aligned_seq_len = get_aligned_seq_len(params, stage);
+            auto block_size = get_query_block_size(stage, can_use_micro_sdpa);
+            paged_attention_aligned_seq_len = get_aligned_seq_len(params, stage, block_size);
         }
         const auto target_seq_len = std::max<int64_t>(paged_attention_aligned_seq_len, 1);
         const auto indexes_buf_size = static_cast<int64_t>(ceil_div(target_seq_len, target_seq_len_block_size)) * element_size;
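
The second fix is about which block size feeds the aligned-sequence-length computation: `can_use_micro_sdpa` is now evaluated before the internal buffers are sized, so `get_aligned_seq_len` can be given the query block size of the kernel that will actually run. A minimal sketch of that alignment arithmetic follows; the block-size constants (256 and 16) and the helper names are placeholders for illustration, not the plugin's real values.

```cpp
#include <cstdint>
#include <iostream>

// Placeholder block sizes for illustration only; the real values come from the plugin.
constexpr int64_t micro_sdpa_query_block_size = 256;
constexpr int64_t default_query_block_size = 16;

int64_t ceil_div(int64_t a, int64_t b) {
    return (a + b - 1) / b;
}

// Align the prefill sequence length up to a multiple of the query block size
// used by the kernel that is actually selected.
int64_t aligned_seq_len(int64_t seq_len, bool can_use_micro_sdpa) {
    const int64_t block_size = can_use_micro_sdpa ? micro_sdpa_query_block_size
                                                  : default_query_block_size;
    return ceil_div(seq_len, block_size) * block_size;
}

int main() {
    // If buffers were sized with one block size while the selected kernel used another,
    // the derived sizes would not match; computing the flag first avoids that mismatch.
    std::cout << aligned_seq_len(1000, /*can_use_micro_sdpa=*/true) << "\n";   // 1024
    std::cout << aligned_seq_len(1000, /*can_use_micro_sdpa=*/false) << "\n";  // 1008
    return 0;
}
```
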
@@ -1431,10 +1436,6 @@
             }
         }

-        bool can_use_micro_sdpa = stage == PagedAttentionStage::PREFILL;
-#ifdef ENABLE_ONEDNN_FOR_GPU
-        can_use_micro_sdpa &= has_stage(pa_sdpa_micro);
-#endif
         if (!can_use_micro_sdpa) {
             // GENERATE/MIXED stages and PREFILL stage without micro_sdpa
             internal_buffers.emplace_back(buf_elements_count * element_size, indexes_dt);  // 5: softmax exp_sums
@@ -1610,7 +1611,6 @@
             const auto block_size = static_cast<int>(query_block_size);
             for (int32_t j = 0; j < seq_length; j += block_size) {
                 auto block_start_pos = subsequence_begins_mem_lock[i] + j;
-
                 micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = block_start_pos;
                 micro_sdpa_block_starts_and_gws_mapping_lock->operator[](micro_sdpa_index++) = static_cast<int32_t>(i);
             }

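The last hunk only drops a blank line inside the loop that records, for every query block of each subsequence, its start position and owning sequence index. For readers unfamiliar with that mapping, here is a small self-contained sketch of the same enumeration; the `subsequence_begins` values and the block size are made up for illustration.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Cumulative start offsets of each subsequence in the flattened token dimension
    // (hypothetical values; the plugin reads these from the subsequence_begins buffer).
    const std::vector<int32_t> subsequence_begins = {0, 40, 100};
    const int32_t block_size = 32;  // placeholder query block size

    // Flat list of (block start position, subsequence index) pairs, analogous to the
    // micro_sdpa_block_starts_and_gws_mapping buffer filled in the hunk above.
    std::vector<int32_t> block_starts_and_mapping;

    for (size_t i = 0; i + 1 < subsequence_begins.size(); ++i) {
        const int32_t seq_length = subsequence_begins[i + 1] - subsequence_begins[i];
        for (int32_t j = 0; j < seq_length; j += block_size) {
            block_starts_and_mapping.push_back(subsequence_begins[i] + j);  // block start
            block_starts_and_mapping.push_back(static_cast<int32_t>(i));    // owning sequence
        }
    }

    for (size_t k = 0; k < block_starts_and_mapping.size(); k += 2) {
        std::cout << "start=" << block_starts_and_mapping[k]
                  << " seq=" << block_starts_and_mapping[k + 1] << "\n";
    }
    // Expected output: start=0 seq=0, start=32 seq=0, start=40 seq=1, start=72 seq=1
    return 0;
}
```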