Commit d80331a

Refactor GroupQueryAttention (#4396)
Remove the GroupQueryAttention ref op and use equivalent ops in its place.
1 parent 0d81fd4 commit d80331a

File tree

45 files changed (+3488, -1810 lines)


src/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -190,6 +190,7 @@ register_migraphx_ops(
     ceil
     clip
     concat
+    concat_past_present
     contiguous
     convert
     convolution
@@ -212,8 +213,8 @@ register_migraphx_ops(
     gather
     gathernd
     get_tuple_elem
+    gqa_rotary_embedding
     greater
-    group_query_attention
     group
     gru
     identity
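
The two new entries replace the single group_query_attention reference op: the GQA decomposition now registers a rotary-embedding op and a kv-cache update op instead. As a quick illustration of how the cache-update op is driven, here is a minimal sketch using MIGraphX's internal C++ API; the parameter names and shapes are invented for the example, while the three-input form and the kv_num_heads attribute come from the operator header further below.

#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    // Hypothetical decode step: batch=1, kv_num_heads=8, head_size=64,
    // one new token appended into a 4096-slot kv-cache.
    auto present = mm->add_parameter(
        "present_k", migraphx::shape{migraphx::shape::float_type, {1, 8, 1, 64}});
    auto seqlens = mm->add_parameter(
        "seqlens_k", migraphx::shape{migraphx::shape::int32_type, {1}});
    auto past = mm->add_parameter(
        "past_k", migraphx::shape{migraphx::shape::float_type, {1, 8, 4096, 64}});
    // The op's output shape is the shape of its last (cache) input.
    mm->add_instruction(migraphx::make_op("concat_past_present", {{"kv_num_heads", 8}}),
                        present,
                        seqlens,
                        past);
    return 0;
}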

src/fuse_attention.cpp

Lines changed: 180 additions & 1 deletion
@@ -209,14 +209,193 @@ struct find_attention
     }
 };
 
+struct find_kv_cache_attention
+{
+    std::size_t* counter;
+
+    auto matcher() const
+    {
+        static const std::unordered_set<std::string> skip_set = {
+            "multibroadcast", "reshape", "unsqueeze"};
+
+        auto keys =
+            match::skip(match::name(skip_set))(match::name("concat_past_present")).bind("pres_k");
+        auto k_transpose =
+            match::skip(match::name(skip_set))(match::name("transpose")(match::arg(0)(keys)));
+        auto queries = match::name("slice");
+        auto gemm1 = match::name("dot")(match::arg(0)(queries), match::arg(1)(k_transpose));
+        auto scale = match::name("mul")(match::any_arg(0, 1)(gemm1));
+        auto broadcasted_const = match::name("multibroadcast")(match::arg(0)(match::is_constant()));
+        auto attn_scores = match::any_of(scale, gemm1);
+        auto causal_mask =
+            match::name("where")(match::arg(0)(broadcasted_const), match::arg(2)(attn_scores));
+        auto greater = match::name("greater")(match::arg(1)(match::any().bind("total_sl")));
+        auto conv_greater =
+            match::skip(match::name("unsqueeze"))(match::name("convert")(match::arg(0)(greater)));
+        auto bc_greater = match::name("multibroadcast")(match::arg(0)(conv_greater));
+        auto mask = match::name("where")(match::arg(0)(bc_greater),
+                                         match::arg(2)(match::any_of(causal_mask, scale, gemm1)));
+        auto attn_probabilities = match::skip(match::name("convert"))(
+            match::softmax_input(match::skip(match::name("convert"))(mask)));
+        auto values =
+            match::skip(match::name(skip_set))(match::name("concat_past_present")).bind("pres_v");
+        auto gemm2 = match::name("dot")(match::arg(0)(attn_probabilities), match::arg(1)(values));
+        auto transpose_out = match::name("transpose")(match::arg(0)(gemm2));
+        return match::name("reshape")(match::arg(0)(transpose_out));
+    }
+
+    std::string get_count() const { return std::to_string((*counter)++); }
+
+    std::unordered_map<instruction_ref, instruction_ref>
+    invert_map_ins(const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const
+    {
+        std::unordered_map<instruction_ref, instruction_ref> inverse_map;
+        for(auto const& [key, value] : map_ins)
+        {
+            assert(not contains(inverse_map, value));
+            inverse_map[value] = key;
+        }
+        return inverse_map;
+    }
+
+    std::vector<instruction_ref>
+    get_attn_instructions(module& m, instruction_ref start, instruction_ref end) const
+    {
+        std::queue<instruction_ref> inputs;
+        std::unordered_set<instruction_ref> inss;
+        inputs.push(end);
+
+        static const std::unordered_set<std::string> valid_attn_ops = {"softmax",
+                                                                       "broadcast",
+                                                                       "dot",
+                                                                       "slice",
+                                                                       "transpose",
+                                                                       "greater",
+                                                                       "convert",
+                                                                       "where",
+                                                                       "reshape",
+                                                                       "reduce_sum",
+                                                                       "reduce_max",
+                                                                       "broadcast",
+                                                                       "multibroadcast",
+                                                                       "@literal",
+                                                                       "unsqueeze"};
+
+        auto is_valid_attn_op = [&](auto i) {
+            return i->get_operator().attributes().get("pointwise", false) or
+                   contains(valid_attn_ops, i->get_operator().name()) or i == start or i == end;
+        };
+
+        while(not inputs.empty())
+        {
+            auto current_inp = inputs.front();
+            inputs.pop();
+
+            if(is_valid_attn_op(current_inp) and inss.insert(current_inp).second and
+               current_inp != start)
+            {
+                for(auto i : current_inp->inputs())
+                {
+                    inputs.push(i);
+                }
+            }
+        }
+        std::vector<instruction_ref> sorted_inss(inss.begin(), inss.end());
+        std::sort(
+            sorted_inss.begin(), sorted_inss.end(), [&](instruction_ref x, instruction_ref y) {
+                return std::distance(m.begin(), x) < std::distance(m.begin(), y);
+            });
+        return sorted_inss;
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto total_sl = r.instructions["total_sl"];
+        auto reshape = r.result;
+
+        // Capture all instructions part of the attention op
+        auto attn_inss = get_attn_instructions(mpm.get_module(), total_sl, reshape);
+
+        // Add captured instructions to new submodule
+        module m_attn;
+        std::unordered_map<instruction_ref, instruction_ref> map_mm_to_mattn;
+        auto attn_outs = m_attn.fuse(attn_inss, &map_mm_to_mattn);
+
+        for(auto ins : iterator_for(m_attn))
+        {
+            if(ins->can_eval())
+            {
+                auto lit_s = ins->get_shape();
+                auto strides = lit_s.strides();
+                if(strides.size() == 4 and
+                   std::all_of(
+                       strides.begin(), strides.end() - 1, [](auto s) { return s == 0; }) and
+                   strides.back() == 1)
+                {
+                    auto new_lit = m_attn.add_literal(
+                        literal{shape{lit_s.type(), {lit_s.lens().back()}}, ins->eval().data()});
+                    m_attn.replace_instruction(
+                        ins, make_op("multibroadcast", {{"out_lens", lit_s.lens()}}), {new_lit});
+                }
+            }
+        }
+        dead_code_elimination{}.apply(m_attn);
+
+        // Define outputs based on instructions that are used elsewhere in the graph
+        std::vector<instruction_ref> required_outputs;
+        std::copy_if(
+            attn_inss.begin(), attn_inss.end(), std::back_inserter(required_outputs), [&](auto i) {
+                return not std::all_of(i->outputs().begin(), i->outputs().end(), [&](auto o) {
+                    return contains(attn_inss, o);
+                });
+            });
+
+        assert(not required_outputs.empty());
+
+        // Find corresponding output instructions in m_attn
+        std::vector<instruction_ref> m_attn_outputs;
+        std::transform(required_outputs.begin(),
+                       required_outputs.end(),
+                       std::back_inserter(m_attn_outputs),
+                       [&](auto i) { return map_mm_to_mattn.at(i); });
+        m_attn.add_return({m_attn_outputs.back()});
+
+        // Define inputs to m_attn
+        auto map_mattn_to_mm = invert_map_ins(map_mm_to_mattn);
+        auto new_inputs = m_attn.get_inputs(map_mattn_to_mm);
+
+        module_ref mpm_attn = mpm.create_module("attn" + get_count(), std::move(m_attn));
+        mpm_attn->set_bypass();
+
+        // Construct group op with the attention module
+        auto group_ins =
+            mpm.get_module().insert_instruction(required_outputs.back(),
+                                                make_op("group", {{"tag", "kv_cache_attention"}}),
+                                                new_inputs,
+                                                {mpm_attn});
+
+        mpm.get_module().replace_instruction(required_outputs.back(), group_ins);
+    }
+};
+
 } // namespace
 
 void fuse_attention::apply(module_pass_manager& mpm) const
 {
     std::size_t counter = 0;
-    match::find_matches(mpm, find_attention{.counter = &counter});
+
+    // Fuse kv-cache attention by default
+    match::find_matches(mpm, find_kv_cache_attention{.counter = &counter});
     mpm.get_module().sort();
     mpm.run_pass(dead_code_elimination{});
+
+    // Only fuse plain attention when requested
+    if(attn_enabled)
+    {
+        match::find_matches(mpm, find_attention{.counter = &counter});
+        mpm.get_module().sort();
+        mpm.run_pass(dead_code_elimination{});
+    }
 }
 
 } // namespace MIGRAPHX_INLINE_NS
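
In short, the matcher roots at the final reshape of a decomposed GQA block and works backwards: queries (a slice) are multiplied against transposed keys coming from a concat_past_present cache update, optionally scaled, masked through where-based causal/padding masks, run through softmax, and multiplied against cached values. get_attn_instructions then captures the surrounding region with a reverse BFS over a whitelist of ops, and apply outlines that region into a submodule wrapped in a group op tagged "kv_cache_attention". For readers unfamiliar with the capture step, here is a standalone sketch of the same traversal on a toy graph type; the node struct and collect_region are illustrative inventions, not MIGraphX API.

#include <queue>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for an instruction; MIGraphX uses instruction_ref instead.
struct node
{
    std::string name;
    std::vector<node*> inputs;
};

// Walk backwards from `end`, keeping only whitelisted ops, and stop at
// `start` so the captured region stays bounded -- the same shape as
// get_attn_instructions above, which additionally sorts the result into
// program order.
std::vector<node*> collect_region(node* start,
                                  node* end,
                                  const std::unordered_set<std::string>& allowed)
{
    std::queue<node*> frontier;
    std::unordered_set<node*> seen;
    frontier.push(end);
    while(not frontier.empty())
    {
        node* n = frontier.front();
        frontier.pop();
        const bool valid = allowed.count(n->name) > 0 or n == start or n == end;
        if(valid and seen.insert(n).second and n != start)
        {
            for(node* in : n->inputs)
                frontier.push(in);
        }
    }
    return {seen.begin(), seen.end()};
}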

src/include/migraphx/fuse_attention.hpp

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ struct module_pass_manager;
 
 struct MIGRAPHX_EXPORT fuse_attention
 {
+    bool attn_enabled = false;
+
     std::string name() const { return "fuse_attention"; }
     void apply(module_pass_manager& mpm) const;
 };
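
With this flag, kv-cache attention fusion always runs while generic attention fusion becomes opt-in. A minimal sketch of enabling it when scheduling the pass, assuming the usual run_passes driver from <migraphx/pass_manager.hpp> (the wrapper function here is invented):

#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/fuse_attention.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>

// Run the pass with plain attention fusion enabled in addition to the
// default kv-cache fusion.
void fuse_all_attention(migraphx::program& p)
{
    migraphx::run_passes(
        p, {migraphx::fuse_attention{.attn_enabled = true}, migraphx::dead_code_elimination{}});
}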
src/include/migraphx/op/concat_past_present.hpp

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 */
#ifndef MIGRAPHX_GUARD_OPERATORS_CONCAT_PAST_PRESENT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONCAT_PAST_PRESENT_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/argument.hpp>
#include <fstream>
#include <iostream>
#include <iomanip>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

struct cache_parameters
{
    std::size_t batch_size = 0;              // Batch size used by input
    std::size_t sequence_length = 0;         // Sequence length used by input
    std::size_t head_size = 0;               // Head size
    std::size_t seqlen_present_kv_cache = 0; // Sequence length of present kv-cache
};

struct concat_past_present
{
    std::size_t kv_num_heads = 0;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.kv_num_heads, "kv_num_heads"));
    }

    std::string name() const { return "concat_past_present"; }

    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(3);
        return inputs.back();
    }

    template <class T>
    void copy_data(T destination, const T source, std::size_t n) const
    {
        par_for(n, [&](auto i) { destination[i] = source[i]; });
    }

    template <typename T>
    T concat_state_chunk(const T chunk,
                         const T present,
                         std::size_t present_buff_chunk_length,
                         std::size_t past_chunk_length,
                         std::size_t new_chunk_length,
                         std::ptrdiff_t i) const
    {
        T start = present + i * present_buff_chunk_length;
        copy_data(start + past_chunk_length, chunk, new_chunk_length);
        return start;
    }

    template <class T, class U>
    void
    update_cache(T past_key, const U seqlens_k, const T present_key, cache_parameters params) const
    {
        const std::size_t batch_size = params.batch_size;
        const std::size_t sequence_length = params.sequence_length;
        const std::size_t head_size = params.head_size;
        const std::size_t past_buffer_sequence_length = params.seqlen_present_kv_cache;
        const std::size_t present_buffer_sequence_length = past_buffer_sequence_length;

        const bool is_prompt = sequence_length != 1;
        const std::size_t packed_batch_stride = kv_num_heads * sequence_length * head_size;
        const std::size_t kv_input_chunk_length = sequence_length * head_size; // L x H
        const std::size_t present_buff_chunk_length =
            present_buffer_sequence_length * head_size; // T x H

        const std::size_t loop_len = batch_size * kv_num_heads;

        par_for(loop_len, [&](const auto i) {
            const std::size_t batch_index = i / kv_num_heads;
            const std::size_t head_index = i % kv_num_heads;
            const std::size_t past_seqlen =
                sequence_length == 1 ? seqlens_k[batch_index] : past_buffer_sequence_length;
            const std::size_t past_chunk_length = is_prompt ? 0 : past_seqlen * head_size;
            auto current = present_key + packed_batch_stride * batch_index +
                           kv_input_chunk_length * head_index;
            concat_state_chunk(current,
                               past_key,
                               present_buff_chunk_length,
                               past_chunk_length,
                               kv_input_chunk_length,
                               i);
        });
    }

    argument compute(const shape& /* output_shape */, std::vector<argument> args) const
    {
        auto present = args[0];
        auto seqlens = args[1];
        auto past = args[2];
        auto present_shape = present.get_shape();
        const auto& present_lens = present_shape.lens();
        const std::size_t batch_size = present_lens[0];
        const std::size_t sequence_length = present_lens[2];
        auto past_kv_shape = past.get_shape();
        const auto& past_kv_lens = past_kv_shape.lens();
        auto past_sequence_length = past_kv_lens[2];
        std::size_t head_size = present_lens[3];

        cache_parameters cache_params = {};
        cache_params.batch_size = batch_size;
        cache_params.sequence_length = sequence_length;
        cache_params.head_size = head_size;
        cache_params.seqlen_present_kv_cache = past_sequence_length;

        visit_all(past, present)([&](auto past_kv, auto present_kv) {
            visit_all(seqlens)([&](auto seqlens_kv) {
                update_cache(past_kv.begin(), seqlens_kv.begin(), present_kv.begin(), cache_params);
            });
        });

        return past;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
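
For intuition about the indexing in update_cache, here is a self-contained sketch of the decode path (sequence_length == 1) with invented sizes: batch 1, two kv heads, head_size 4, an 8-slot cache, and 5 tokens already cached. The new token lands at row seqlens_k[b] of each head's (cache_len x head_size) block, exactly as concat_state_chunk computes.

#include <algorithm>
#include <cstddef>
#include <vector>

int main()
{
    const std::size_t kv_num_heads = 2, head_size = 4, cache_len = 8;
    const std::size_t seqlens_k[1] = {5}; // tokens already in the cache for batch 0
    std::vector<float> past(kv_num_heads * cache_len * head_size, 0.f);
    std::vector<float> present(kv_num_heads * 1 * head_size, 1.f);

    for(std::size_t i = 0; i < kv_num_heads; ++i) // loop_len = batch_size * kv_num_heads
    {
        // chunk: this head's new (1 x head_size) slice of `present`
        const float* chunk = present.data() + i * head_size;
        // start + past_chunk_length: row seqlens_k[0] of this head's cache block
        float* dst = past.data() + i * cache_len * head_size // present_buff_chunk_length
                     + seqlens_k[0] * head_size;             // past_chunk_length
        std::copy(chunk, chunk + head_size, dst);
    }
    // past now holds the new key/value row at position 5 for both heads.
    return 0;
}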
