93 commits
d290f17
Initial changes
turneram Feb 5, 2025
9a43f78
Fix invalid graph
turneram Feb 6, 2025
2243f8c
Init
turneram Feb 10, 2025
c4480d5
Fix benchmarking fault
turneram Feb 14, 2025
2842082
Test
turneram Mar 4, 2025
df1e354
Fix sl>1
turneram Mar 5, 2025
1e4d065
Test
turneram Mar 5, 2025
ed53a45
Additional tests
turneram Mar 10, 2025
00930d2
Update mlir commit
turneram Mar 10, 2025
86885a2
Use int32 in test
turneram Mar 10, 2025
81a1fc1
Remove debug print
turneram Mar 14, 2025
9211184
Make concat_past_present work on k,v separately
turneram Mar 20, 2025
6e56851
Use migx ops instead of mlir fusion
turneram Mar 21, 2025
dc72adb
Use inputs(5)
turneram Mar 30, 2025
36b7fab
Update mlir commit hash
turneram Mar 31, 2025
5219097
Clean up and make prompt test use inputs that reflects e2e data
turneram Mar 31, 2025
ed35af2
Formatting
turneram Mar 31, 2025
f850e6b
Update tests
turneram Mar 31, 2025
c1ce286
Merge remote-tracking branch 'origin/develop' into mlir-kv-cache
turneram Mar 31, 2025
38951a7
Formatting
turneram Mar 31, 2025
362cd24
Add counter and supporting differing num_heads
turneram May 5, 2025
ed70b01
Merge remote-tracking branch 'origin/develop' into mlir-kv-cache
turneram May 5, 2025
4104106
Merge branch 'develop' into mlir-kv-cache
apwojcik May 12, 2025
9e0c1b9
Use greater and remove greater_or_equal
turneram May 13, 2025
c375b7e
Merge remote-tracking branch 'origin/mlir-kv-cache' into mlir-kv-cache
turneram May 13, 2025
1c968fe
Support new mlir causal masking and update tests
turneram May 22, 2025
eb1502a
Formatting
turneram May 22, 2025
e2484af
License stamp
turneram May 22, 2025
24e7c17
Merge remote-tracking branch 'origin/develop' into mlir-kv-cache
turneram May 22, 2025
1429641
Remove test with issue to see if error is specific to this case
turneram Jun 2, 2025
c5817d3
Same thing; next test
turneram Jun 3, 2025
13576b6
Add back tests and add trace_eval
turneram Jun 3, 2025
0bcc470
Remove trace_eval
turneram Jun 9, 2025
c5451d9
Merge branch 'develop' into mlir-kv-cache
apwojcik Jun 11, 2025
5b15541
fix group query attention tests compilation on Windows
apwojcik Jun 11, 2025
2e18f42
Add group op to use with fused attention
turneram Jun 11, 2025
9c47226
Rename tests
turneram Jun 11, 2025
00b5222
Merge branch 'mlir-kv-cache' of https://github.com/ROCm/AMDMIGraphX i…
turneram Jun 11, 2025
70ac554
Merge remote-tracking branch 'origin/develop' into mlir-kv-cache
turneram Jun 11, 2025
8096832
Remove kv_cache_attention prefuse op
turneram Jun 11, 2025
8932831
Update tests names within files
turneram Jun 11, 2025
1867ecf
Tidy
turneram Jun 12, 2025
a7ab512
Formatting
turneram Jun 12, 2025
b6ea84a
Formatting
turneram Jun 12, 2025
7fdc47c
Add group shape test
turneram Jun 12, 2025
f944bc9
Formatting
turneram Jun 12, 2025
1029825
Update set_fill_map to look for greater instead of greater_or_equal
turneram Jun 13, 2025
13bd253
fix cppcheck for Windows parts
apwojcik Jun 13, 2025
19d28bc
remove unused macros
apwojcik Jun 16, 2025
8c41244
Avoid calling as_standard on tuple_type
turneram Jun 16, 2025
6a464e7
Formatting
turneram Jun 16, 2025
779a436
Merge branch 'mlir-kv-cache' of https://github.com/ROCm/AMDMIGraphX i…
turneram Jun 16, 2025
ace7cbd
Merge remote-tracking branch 'origin/develop' into mlir-kv-cache
turneram Jun 16, 2025
9ee8f74
Add multi-output case to group_op test
turneram Jun 16, 2025
d55bd61
Formatting
turneram Jun 16, 2025
18a2228
Merge branch 'develop' into mlir-kv-cache
apwojcik Jun 17, 2025
c7e9fcb
Merge remote-tracking branch 'origin/develop' into refactor-gqa
turneram Jul 22, 2025
3da1159
Initial changes
turneram Sep 16, 2025
74cb33c
Formatting
turneram Sep 16, 2025
3302493
Update mlir commit
turneram Oct 7, 2025
f7ce9b2
Cleanup
turneram Oct 7, 2025
e9a1963
Formatting
turneram Oct 7, 2025
9c6d096
Merge remote-tracking branch 'origin/develop' into refactor-gqa
turneram Oct 7, 2025
e74aa13
Update matcher for fp32 softmax
turneram Oct 8, 2025
ed7e1b8
Formatting
turneram Oct 8, 2025
c00a45b
Cleanup and verify tests
turneram Oct 15, 2025
cfb3641
Formatting
turneram Oct 15, 2025
51891a4
Merge remote-tracking branch 'origin/develop' into refactor-gqa
turneram Oct 15, 2025
4e7cc40
Formatting
turneram Oct 15, 2025
e06f3cd
More tests
turneram Oct 15, 2025
88ff18f
Formatting
turneram Oct 15, 2025
33a3acd
Formatting
turneram Oct 15, 2025
e9d7257
Merge remote-tracking branch 'origin/develop' into refactor-gqa
turneram Oct 16, 2025
087c049
Add local attention support
turneram Oct 16, 2025
ec0747c
Formatting
turneram Oct 16, 2025
0eaa96c
Add onnx verify tests
turneram Oct 17, 2025
cc0a2fe
Formatting
turneram Oct 17, 2025
bd74ab9
Merge remote-tracking branch 'origin/develop' into refactor-gqa
turneram Oct 17, 2025
5aa7cf6
Merge remote-tracking branch 'origin/develop' into refactor-gqa
turneram Oct 22, 2025
d13f81c
Make generate on 7.0
turneram Oct 24, 2025
28aeebf
Fix typo
turneram Oct 24, 2025
a101ee8
Add eof newline
turneram Oct 24, 2025
7198169
Formatting
turneram Oct 24, 2025
63f111a
Licensing
turneram Oct 24, 2025
604a373
Fix some tests
turneram Oct 27, 2025
8ac0469
Formatting
turneram Oct 27, 2025
dcb3a4f
Merge remote-tracking branch 'origin/develop' into refactor-gqa
turneram Oct 27, 2025
fbe05d6
Fix fuse_attention tests
turneram Oct 27, 2025
8ccb3e9
Merge remote-tracking branch 'origin/develop' into refactor-gqa
turneram Oct 27, 2025
65367b3
Formatting
turneram Oct 27, 2025
39eb8c6
Tidy, cppcheck, and copilot
turneram Oct 28, 2025
91af07b
Formatting
turneram Oct 28, 2025
2fd814e
Merge remote-tracking branch 'origin/develop' into refactor-gqa
turneram Oct 28, 2025
2 changes: 1 addition & 1 deletion requirements.txt
@@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/[email protected] -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5
[email protected] -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@fe6da4db4d6f0da8c74e28a0787cfbb4a026550a -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
ROCm/rocMLIR@0100c11941426b7ad6f0724d51025fa33d227821 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
@@ -190,6 +190,7 @@ register_migraphx_ops(
ceil
clip
concat
concat_past_present
contiguous
convert
convolution
@@ -212,6 +213,7 @@
gather
gathernd
get_tuple_elem
gqa_rotary_embedding
greater
group_query_attention
group
179 changes: 178 additions & 1 deletion src/fuse_attention.cpp
@@ -209,14 +209,191 @@
}
};

struct find_kv_cache_attention
{
std::size_t* counter;

auto matcher() const
{
static const std::unordered_set<std::string> skip_set = {
"multibroadcast", "reshape", "unsqueeze"};

auto transpose1 = match::skip(match::name(skip_set))(match::name("transpose")(
match::arg(0)(match::skip(match::name(skip_set))(match::name("concat_past_present"))
.bind("pres_k"))));
auto gemm1 =
match::name("dot")(match::arg(0)(match::name("slice")), match::arg(1)(transpose1));
auto scale = match::name("mul")(match::any_arg(0, 1)(gemm1));
auto broadcasted_const = match::name("multibroadcast")(match::arg(0)(match::is_constant()));
auto attn_scores = match::any_of(scale, gemm1);
auto causal_mask =
match::name("where")(match::arg(0)(broadcasted_const), match::arg(2)(attn_scores));
auto greater = match::name("multibroadcast")(match::arg(0)(match::name("convert")(
match::arg(0)(match::name("greater")(match::arg(1)(match::any().bind("total_sl")))))));

Check warning on line 232 in src/fuse_attention.cpp (GitHub Actions / cppcheck): style: Too many nested parentheses can affect readability; consider using variables instead. [migraphx-MatcherNestedParentheses]
auto where = match::name("where")(match::arg(0)(greater),
match::arg(2)(match::any_of(causal_mask, scale, gemm1)));
auto softmax = match::skip(match::name("convert"))(
match::softmax_input(match::skip(match::name("convert"))(where)));
auto gemm2 = match::name("dot")(
match::arg(0)(softmax),
match::arg(1)(match::skip(match::name(skip_set))(match::name("concat_past_present"))
.bind("pres_v")));
auto transpose2 = match::name("transpose")(match::arg(0)(gemm2));
return match::name("reshape")(match::arg(0)(transpose2));
}
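// Rough sketch of the subgraph the matcher above targets (multibroadcast/reshape/unsqueeze
// hops are skipped via skip_set; the scale and causal-mask stages are optional):
//
//   pres_k = concat_past_present(...)                            // cached keys
//   scores = dot(slice(q), transpose(pres_k))
//   scores = mul(scores, scale)                                  // optional scaling
//   scores = where(multibroadcast(const), _, scores)             // optional causal mask
//   scores = where(multibroadcast(convert(greater(_, total_sl))), _, scores)
//   probs  = softmax(scores)
//   out    = dot(probs, concat_past_present(...))                // pres_v: cached values
//   result = reshape(transpose(out))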

std::string get_count() const { return std::to_string((*counter)++); }

std::unordered_map<instruction_ref, instruction_ref>
invert_map_ins(const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const
{
std::unordered_map<instruction_ref, instruction_ref> inverse_map;
for(auto const& [key, value] : map_ins)
{
assert(not contains(inverse_map, value));
inverse_map[value] = key;
}
return inverse_map;
}

// Collect the instructions that form the attention subgraph: walk inputs backwards from
// `end`, keeping pointwise ops and ops in valid_attn_ops (stopping the walk at `start`),
// then return them sorted in program order so they can be fused as a single unit
std::vector<instruction_ref>
get_attn_instructions(module& m, instruction_ref start, instruction_ref end) const
{
std::queue<instruction_ref> inputs;
std::unordered_set<instruction_ref> inss;
inputs.push(end);

static const std::unordered_set<std::string> valid_attn_ops = {"softmax",
"broadcast",
"dot",
"slice",
"transpose",
"greater",
"convert",
"where",
"reshape",
"reduce_sum",
"reduce_max",
"broadcast",
"multibroadcast",
"@literal",
"unsqueeze"};

auto is_valid_attn_op = [&](auto i) {
return i->get_operator().attributes().get("pointwise", false) or
contains(valid_attn_ops, i->get_operator().name()) or i == start or i == end;
};

while(not inputs.empty())
{
auto current_inp = inputs.front();
inputs.pop();

if(is_valid_attn_op(current_inp) and inss.insert(current_inp).second and
current_inp != start)
{
for(auto i : current_inp->inputs())
{
inputs.push(i);
}
}
}
std::vector<instruction_ref> sorted_inss(inss.begin(), inss.end());
std::sort(
sorted_inss.begin(), sorted_inss.end(), [&](instruction_ref x, instruction_ref y) {
return std::distance(m.begin(), x) < std::distance(m.begin(), y);
});
return sorted_inss;
}

// Extract the matched attention subgraph into its own submodule and replace it in the
// parent module with a single "group" instruction tagged "kv_cache_attention"
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto total_sl = r.instructions["total_sl"];
auto reshape = r.result;

// Capture all instructions that are part of the attention op
auto attn_inss = get_attn_instructions(mpm.get_module(), total_sl, reshape);

// Add captured instructions to new submodule
module m_attn;
std::unordered_map<instruction_ref, instruction_ref> map_mm_to_mattn;
auto attn_outs = m_attn.fuse(attn_inss, &map_mm_to_mattn);

// Rewrite constant literals that are effectively 1-d values broadcast along the last
// axis (all strides zero except an innermost stride of 1) into a smaller literal plus
// an explicit multibroadcast
for(auto ins : iterator_for(m_attn))
{
if(ins->can_eval())
{
auto lit_s = ins->get_shape();
auto strides = lit_s.strides();
if(strides.size() == 4 and
std::all_of(
strides.begin(), strides.end() - 1, [](auto s) { return s == 0; }) and
strides.back() == 1)
{
auto new_lit = m_attn.add_literal(
literal{shape{lit_s.type(), {lit_s.lens().back()}}, ins->eval().data()});
m_attn.replace_instruction(
ins, make_op("multibroadcast", {{"out_lens", lit_s.lens()}}), {new_lit});
}
}
}
dead_code_elimination{}.apply(m_attn);

// Define outputs based on instructions that are used elsewhere in the graph
std::vector<instruction_ref> required_outputs;
std::copy_if(
attn_inss.begin(), attn_inss.end(), std::back_inserter(required_outputs), [&](auto i) {
return not std::all_of(i->outputs().begin(), i->outputs().end(), [&](auto o) {
return contains(attn_inss, o);
});
});

assert(not required_outputs.empty());

// Find corresponding output instructions in m_attn
std::vector<instruction_ref> m_attn_outputs;
std::transform(required_outputs.begin(),
required_outputs.end(),
std::back_inserter(m_attn_outputs),
[&](auto i) { return map_mm_to_mattn.at(i); });
m_attn.add_return({m_attn_outputs.back()});

// Define inputs to m_attn
auto map_mattn_to_mm = invert_map_ins(map_mm_to_mattn);
auto new_inputs = m_attn.get_inputs(map_mattn_to_mm);

module_ref mpm_attn = mpm.create_module("attn" + get_count(), std::move(m_attn));
mpm_attn->set_bypass();

// Construct group op with the attention module
auto group_ins =
mpm.get_module().insert_instruction(required_outputs.back(),
make_op("group", {{"tag", "kv_cache_attention"}}),
new_inputs,
{mpm_attn});

mpm.get_module().replace_instruction(required_outputs.back(), group_ins);
}
};

} // namespace

void fuse_attention::apply(module_pass_manager& mpm) const
{
std::size_t counter = 0;
match::find_matches(mpm, find_attention{.counter = &counter});

// Fuse kv-cache attention by default
match::find_matches(mpm, find_kv_cache_attention{.counter = &counter});
mpm.get_module().sort();
mpm.run_pass(dead_code_elimination{});

// Only fuse plain attention when requested
if(attn_enabled)
{
match::find_matches(mpm, find_attention{.counter = &counter});
mpm.get_module().sort();
mpm.run_pass(dead_code_elimination{});
}
}

} // namespace MIGRAPHX_INLINE_NS
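The updated apply() fuses kv-cache attention by default and only runs the plain find_attention matcher when attn_enabled is set. A minimal sketch of driving the pass in test style, assuming the run_passes helper from pass_manager.hpp; the helper name and flag wiring below are illustrative:

#include <migraphx/fuse_attention.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>

// Hypothetical helper: run attention fusion over a program, optionally also
// fusing plain (non-kv-cache) attention patterns.
void run_attention_fusion(migraphx::program& p, bool include_plain_attention)
{
    migraphx::run_passes(p,
                         {migraphx::fuse_attention{.attn_enabled = include_plain_attention},
                          migraphx::dead_code_elimination{}});
}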
2 changes: 2 additions & 0 deletions src/include/migraphx/fuse_attention.hpp
@@ -35,6 +35,8 @@ struct module_pass_manager;

struct MIGRAPHX_EXPORT fuse_attention
{
bool attn_enabled = false;

std::string name() const { return "fuse_attention"; }
void apply(module_pass_manager& mpm) const;
};
161 changes: 161 additions & 0 deletions src/include/migraphx/op/concat_past_present.hpp
@@ -0,0 +1,161 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_CONCAT_PAST_PRESENT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONCAT_PAST_PRESENT_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/argument.hpp>
#include <fstream>
#include <iostream>
#include <iomanip>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct cache_parameters
{
std::size_t batch_size = 0; // Batch size used by input
std::size_t sequence_length = 0; // Sequence length used by input
std::size_t head_size = 0; // Head size
std::size_t num_heads = 0; // num_heads = hidden_size / head_size
std::size_t seqlen_present_kv_cache = 0; // Sequence length of present kv-cache
};

struct concat_past_present
{
std::size_t kv_num_heads = 0;
std::size_t num_heads = 1;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.kv_num_heads, "kv_num_heads"), f(self.num_heads, "num_heads"));
}

std::string name() const { return "concat_past_present"; }

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(3);
return inputs.back();
}

template <class T>
void copy_data(T destination, const T source, std::size_t n) const
{
par_for(n, [&](auto i) { destination[i] = source[i]; });
}

template <typename T>
T concat_state_chunk(const T chunk,
const T present,
std::size_t present_buff_chunk_length,
std::size_t past_chunk_length,
std::size_t new_chunk_length,
std::ptrdiff_t i) const
{
T start = present + i * present_buff_chunk_length;
copy_data(start + past_chunk_length, chunk, new_chunk_length);
return start;
}

template <class T, class U>
void
update_cache(T past_key, const U seqlens_k, const T present_key, cache_parameters params) const
{
const std::size_t batch_size = params.batch_size;
const std::size_t sequence_length = params.sequence_length;
const std::size_t head_size = params.head_size;
const std::size_t past_buffer_sequence_length = params.seqlen_present_kv_cache;
const std::size_t present_buffer_sequence_length = past_buffer_sequence_length;

const bool is_prompt = sequence_length != 1;
const std::size_t packed_batch_stride =
(num_heads + 2 * kv_num_heads) * sequence_length * head_size;
const std::size_t kv_num_heads_factor = num_heads / kv_num_heads;
const std::size_t kv_input_chunk_length = sequence_length * head_size; // L x H
const std::size_t present_buff_chunk_length =
present_buffer_sequence_length * head_size; // T x H

const std::size_t loop_len = batch_size * num_heads;

par_for(loop_len, [&](const auto i) {
const std::size_t batch_index = i / num_heads;
const std::size_t head_index = i % num_heads;
const std::size_t past_seqlen =
sequence_length == 1 ? seqlens_k[batch_index] : past_buffer_sequence_length;
const std::size_t past_chunk_length = is_prompt ? 0 : past_seqlen * head_size;

auto current = present_key + packed_batch_stride * batch_index +
kv_input_chunk_length * (head_index / kv_num_heads_factor);
concat_state_chunk(current,
past_key,
present_buff_chunk_length,
past_chunk_length,
kv_input_chunk_length,
i / kv_num_heads_factor);
});
}
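// Worked example with illustrative numbers (not taken from a real model): for
// num_heads = 2, kv_num_heads = 1, head_size = 4, sequence_length = 1 and
// seqlens_k[batch] = 5, we get kv_num_heads_factor = 2, kv_input_chunk_length = 4 and
// past_chunk_length = 20, so both query heads map to the same kv-head chunk and the
// 4 new values are appended at element offset 20 of that chunk in the cache.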

argument compute(const shape& /* output_shape */, std::vector<argument> args) const
{
auto present = args[0];
auto seqlens = args[1];
auto past = args[2];
auto present_shape = present.get_shape();
const auto& present_lens = present_shape.lens();
const std::size_t batch_size = present_lens[0];
const std::size_t sequence_length = present_lens[2];
auto past_kv_shape = past.get_shape();
const auto& past_kv_lens = past_kv_shape.lens();
auto past_sequence_length = past_kv_lens[2];
std::size_t head_size = present_lens[3];

cache_parameters cache_params = {};
cache_params.batch_size = batch_size;
cache_params.sequence_length = sequence_length;
cache_params.head_size = head_size;
cache_params.num_heads = num_heads;
cache_params.seqlen_present_kv_cache = past_sequence_length;

visit_all(past, present)([&](auto past_kv, auto present_kv) {
visit_all(seqlens)([&](auto seqlens_kv) {
update_cache(past_kv.begin(), seqlens_kv.begin(), present_kv.begin(), cache_params);
});
});

return past;
}
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
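For reference, a minimal sketch of constructing the operator above with make_op. The attribute names follow reflect() and the argument order follows compute() ({present, seqlens_k, past}); the head counts and helper name are illustrative assumptions:

#include <migraphx/module.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction_ref.hpp>

// Hypothetical helper: append the new key (or value) states to the past cache.
// The result has the shape of the past cache, since compute_shape() returns inputs.back().
migraphx::instruction_ref append_kv_cache(migraphx::module& m,
                                           migraphx::instruction_ref present,
                                           migraphx::instruction_ref seqlens_k,
                                           migraphx::instruction_ref past)
{
    return m.add_instruction(
        migraphx::make_op("concat_past_present", {{"kv_num_heads", 8}, {"num_heads", 32}}),
        present,
        seqlens_k,
        past);
}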