@@ -209,14 +209,193 @@ struct find_attention
     }
 };
 
+struct find_kv_cache_attention
+{
+    std::size_t* counter;
+
+    auto matcher() const
+    {
+        static const std::unordered_set<std::string> skip_set = {
+            "multibroadcast", "reshape", "unsqueeze"};
+
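+        // Matches the kv-cache attention subgraph: sliced queries multiplied with the
+        // transposed keys from concat_past_present, optionally scaled, masked by a broadcast
+        // causal mask and a sequence-length mask built from greater(..., total_sl), run
+        // through softmax, multiplied with the values from concat_past_present, and finished
+        // with a transpose + reshape.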
+        auto keys =
+            match::skip(match::name(skip_set))(match::name("concat_past_present")).bind("pres_k");
+        auto k_transpose =
+            match::skip(match::name(skip_set))(match::name("transpose")(match::arg(0)(keys)));
+        auto queries = match::name("slice");
+        auto gemm1 = match::name("dot")(match::arg(0)(queries), match::arg(1)(k_transpose));
+        auto scale = match::name("mul")(match::any_arg(0, 1)(gemm1));
+        auto broadcasted_const = match::name("multibroadcast")(match::arg(0)(match::is_constant()));
+        auto attn_scores = match::any_of(scale, gemm1);
+        auto causal_mask =
+            match::name("where")(match::arg(0)(broadcasted_const), match::arg(2)(attn_scores));
+        auto greater = match::name("greater")(match::arg(1)(match::any().bind("total_sl")));
+        auto conv_greater =
+            match::skip(match::name("unsqueeze"))(match::name("convert")(match::arg(0)(greater)));
+        auto bc_greater = match::name("multibroadcast")(match::arg(0)(conv_greater));
+        auto mask = match::name("where")(match::arg(0)(bc_greater),
+                                         match::arg(2)(match::any_of(causal_mask, scale, gemm1)));
+        auto attn_probabilities = match::skip(match::name("convert"))(
+            match::softmax_input(match::skip(match::name("convert"))(mask)));
+        auto values =
+            match::skip(match::name(skip_set))(match::name("concat_past_present")).bind("pres_v");
+        auto gemm2 = match::name("dot")(match::arg(0)(attn_probabilities), match::arg(1)(values));
+        auto transpose_out = match::name("transpose")(match::arg(0)(gemm2));
+        return match::name("reshape")(match::arg(0)(transpose_out));
+    }
+
+    std::string get_count() const { return std::to_string((*counter)++); }
+
+    std::unordered_map<instruction_ref, instruction_ref>
+    invert_map_ins(const std::unordered_map<instruction_ref, instruction_ref>& map_ins) const
+    {
+        std::unordered_map<instruction_ref, instruction_ref> inverse_map;
+        for(auto const& [key, value] : map_ins)
+        {
+            assert(not contains(inverse_map, value));
+            inverse_map[value] = key;
+        }
+        return inverse_map;
+    }
+
+    std::vector<instruction_ref>
+    get_attn_instructions(module& m, instruction_ref start, instruction_ref end) const
+    {
+        std::queue<instruction_ref> inputs;
+        std::unordered_set<instruction_ref> inss;
+        inputs.push(end);
+
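+        // Ops (in addition to any pointwise op) that are allowed to be absorbed into the
+        // fused attention submodule.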
+        static const std::unordered_set<std::string> valid_attn_ops = {"softmax",
+                                                                       "broadcast",
+                                                                       "dot",
+                                                                       "slice",
+                                                                       "transpose",
+                                                                       "greater",
+                                                                       "convert",
+                                                                       "where",
+                                                                       "reshape",
+                                                                       "reduce_sum",
+                                                                       "reduce_max",
+                                                                       "broadcast",
+                                                                       "multibroadcast",
+                                                                       "@literal",
+                                                                       "unsqueeze"};
+
+        auto is_valid_attn_op = [&](auto i) {
+            return i->get_operator().attributes().get("pointwise", false) or
+                   contains(valid_attn_ops, i->get_operator().name()) or i == start or i == end;
+        };
+
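+        // Walk the graph backwards from the output, collecting valid attention instructions;
+        // traversal stops at invalid ops and at the start instruction (the bound total_sl).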
+        while(not inputs.empty())
+        {
+            auto current_inp = inputs.front();
+            inputs.pop();
+
+            if(is_valid_attn_op(current_inp) and inss.insert(current_inp).second and
+               current_inp != start)
+            {
+                for(auto i : current_inp->inputs())
+                {
+                    inputs.push(i);
+                }
+            }
+        }
+        std::vector<instruction_ref> sorted_inss(inss.begin(), inss.end());
+        std::sort(
+            sorted_inss.begin(), sorted_inss.end(), [&](instruction_ref x, instruction_ref y) {
+                return std::distance(m.begin(), x) < std::distance(m.begin(), y);
+            });
+        return sorted_inss;
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto total_sl = r.instructions["total_sl"];
+        auto reshape = r.result;
+
+        // Capture all instructions that are part of the attention op
+        auto attn_inss = get_attn_instructions(mpm.get_module(), total_sl, reshape);
+
+        // Add captured instructions to a new submodule
+        module m_attn;
+        std::unordered_map<instruction_ref, instruction_ref> map_mm_to_mattn;
+        auto attn_outs = m_attn.fuse(attn_inss, &map_mm_to_mattn);
+
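+        // Literals broadcast along only the last axis (strides {0, 0, 0, 1}) are re-added as
+        // 1-D literals followed by an explicit multibroadcast so the submodule does not
+        // capture the full broadcasted tensor.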
+        for(auto ins : iterator_for(m_attn))
+        {
+            if(ins->can_eval())
+            {
+                auto lit_s = ins->get_shape();
+                auto strides = lit_s.strides();
+                if(strides.size() == 4 and
+                   std::all_of(
+                       strides.begin(), strides.end() - 1, [](auto s) { return s == 0; }) and
+                   strides.back() == 1)
+                {
+                    auto new_lit = m_attn.add_literal(
+                        literal{shape{lit_s.type(), {lit_s.lens().back()}}, ins->eval().data()});
+                    m_attn.replace_instruction(
+                        ins, make_op("multibroadcast", {{"out_lens", lit_s.lens()}}), {new_lit});
+                }
+            }
+        }
+        dead_code_elimination{}.apply(m_attn);
+
+        // Define outputs based on instructions that are used elsewhere in the graph
+        std::vector<instruction_ref> required_outputs;
+        std::copy_if(
+            attn_inss.begin(), attn_inss.end(), std::back_inserter(required_outputs), [&](auto i) {
+                return not std::all_of(i->outputs().begin(), i->outputs().end(), [&](auto o) {
+                    return contains(attn_inss, o);
+                });
+            });
+
+        assert(not required_outputs.empty());
+
+        // Find corresponding output instructions in m_attn
+        std::vector<instruction_ref> m_attn_outputs;
+        std::transform(required_outputs.begin(),
+                       required_outputs.end(),
+                       std::back_inserter(m_attn_outputs),
+                       [&](auto i) { return map_mm_to_mattn.at(i); });
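+        // Only the last required output (the final reshape) is returned from the submodule.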
+        m_attn.add_return({m_attn_outputs.back()});
+
+        // Define inputs to m_attn
+        auto map_mattn_to_mm = invert_map_ins(map_mm_to_mattn);
+        auto new_inputs = m_attn.get_inputs(map_mattn_to_mm);
+
+        module_ref mpm_attn = mpm.create_module("attn" + get_count(), std::move(m_attn));
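+        // Mark the submodule as bypass so subsequent module passes skip over it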
+        mpm_attn->set_bypass();
+
+        // Construct group op with the attention module
+        auto group_ins =
+            mpm.get_module().insert_instruction(required_outputs.back(),
+                                                make_op("group", {{"tag", "kv_cache_attention"}}),
+                                                new_inputs,
+                                                {mpm_attn});
+
+        mpm.get_module().replace_instruction(required_outputs.back(), group_ins);
+    }
+};
+
 } // namespace
 
 void fuse_attention::apply(module_pass_manager& mpm) const
 {
     std::size_t counter = 0;
-    match::find_matches(mpm, find_attention{.counter = &counter});
+
+    // Fuse kv-cache attention by default
+    match::find_matches(mpm, find_kv_cache_attention{.counter = &counter});
     mpm.get_module().sort();
     mpm.run_pass(dead_code_elimination{});
+
+    // Only fuse plain attention when requested
+    if(attn_enabled)
+    {
+        match::find_matches(mpm, find_attention{.counter = &counter});
+        mpm.get_module().sort();
+        mpm.run_pass(dead_code_elimination{});
+    }
 }
 
 } // namespace MIGRAPHX_INLINE_NS