Skip to content

Commit 942e080

Browse files
authored
[fix] Fix missing fields in xqa kernel cache key (#6282)
Signed-off-by: Yao Yao <[email protected]>
1 parent fbee279 commit 942e080

File tree

4 files changed

+11
-6
lines changed

4 files changed

+11
-6
lines changed

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,8 @@ XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParam
5555
// precompiled XQA does not use is_fp8_output as hashing key
5656
return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, kernel_m_tilesize,
5757
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
58-
xqaParams.multi_query_tokens, isXqaJit ? xqaParams.is_fp8_output : false};
58+
xqaParams.multi_query_tokens, isXqaJit ? xqaParams.is_fp8_output : false,
59+
isXqaJit ? std::optional(xqaParams.position_embedding_type) : std::nullopt};
5960
}
6061

6162
} // namespace tensorrt_llm::kernels

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,14 +67,15 @@ struct XQAKernelRuntimeHashKey
6767
bool paged_kv_cache;
6868
bool multi_query_tokens;
6969
bool is_fp8_output;
70+
std::optional<PositionEmbeddingType> position_embedding_type;
7071

7172
bool operator==(XQAKernelRuntimeHashKey const& other) const
7273
{
7374
return kv_data_type == other.kv_data_type && head_size == other.head_size
7475
&& num_q_heads_per_kv == other.num_q_heads_per_kv && beam_size == other.beam_size
7576
&& multi_query_tokens == other.multi_query_tokens && m_tilesize == other.m_tilesize
7677
&& tokens_per_page == other.tokens_per_page && paged_kv_cache == other.paged_kv_cache
77-
&& is_fp8_output == other.is_fp8_output;
78+
&& is_fp8_output == other.is_fp8_output && position_embedding_type == other.position_embedding_type;
7879
}
7980
};
8081

@@ -103,6 +104,8 @@ struct XQAKernelRuntimeHasher
103104
key ^= s.multi_query_tokens;
104105
key <<= 1;
105106
key ^= s.is_fp8_output;
107+
key <<= 8;
108+
key ^= static_cast<int8_t>(s.position_embedding_type.value_or(static_cast<PositionEmbeddingType>(-1)));
106109
return key;
107110
}
108111
};

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,8 @@ using ::tensorrt_llm::kernels::XQAKernelMetaInfo;
3737
XQAKernelRuntimeHashKey getRuntimeHashKeyFromKernelMeta(XQAKernelMetaInfo const& kernelMeta)
3838
{
3939
return {kernelMeta.mKVDataType, kernelMeta.mHeadDim, kernelMeta.mBeamWidth, kernelMeta.mNumQHeadsOverKV,
40-
kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens,
41-
0 /* xqa jit param is_fp8_output */};
40+
kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens, false,
41+
std::nullopt};
4242
}
4343

4444
} // anonymous namespace

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ class XQAKernelList
9797
}
9898
XQAKernelRuntimeHashKey hash_key{kernelMeta.mKVDataType, kernelMeta.mHeadDim, kernelMeta.mBeamWidth,
9999
kernelMeta.mNumQHeadsOverKV, kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache,
100-
kernelMeta.mMultiQueryTokens, 0 /* xqa jit param is_fp8_output */};
100+
kernelMeta.mMultiQueryTokens, false, std::nullopt};
101101

102102
mFunctions.insert(std::make_pair(hash_key, funcInfo));
103103
}
@@ -128,7 +128,8 @@ class XQAKernelList
128128
XQAKernelRuntimeHashKey hash_key
129129
= {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, m_tilesize,
130130
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0,
131-
xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, 0 /* xqa jit param is_fp8_output */};
131+
xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, 0, /* xqa jit param is_fp8_output */
132+
std::nullopt};
132133
auto const findIter = mFunctions.find(hash_key);
133134
return findIter != mFunctions.end();
134135
}

0 commit comments

Comments
 (0)