Commit 66b6e36

Flash decoding round 1; AIMIGRAPHX-242 (#4393)
rocMLIR has implemented split-KV and grouped-query attention (GQA), which enables us to implement flash decoding.
1 parent: d80331a

9 files changed: +1239 −8 lines

docs/reference/MIGraphX-dev-env-vars.rst

Lines changed: 8 additions & 0 deletions
@@ -160,6 +160,14 @@ Model performance tunable variables change the compilation behavior of a model.
 
     | Default: Split-k performance configurations are turned off.
 
+  * - | ``MIGRAPHX_FLASH_DECODING_NUM_SPLITS``
+      | Turns on flash decoding for attention fusion and sets the number of splits along the key-value sequence dimension.
+
+      | ``0``: Flash decoding is turned off (i.e., the number of splits is 0).
+      | ``N`` (where N > 1): Enables flash decoding with N splits along the key-value sequence dimension. For example, ``2`` enables flash decoding with 2 splits, ``4`` with 4 splits, etc.
+
+      | Default: Flash decoding is turned off.
+
   * - | ``MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT``
       | When set, FP16 is not converted to FP32 in the ``InstanceNormalization`` ONNX operator.
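For context on what "N splits along the key-value sequence dimension" buys: flash decoding lets each split compute a partial softmax over its own chunk of KV rows, and the partials recombine exactly through a log-sum-exp rescale. A reference formulation (background, not quoted from the diff), with attention scores x_j = q·k_j and value rows v_j, where c ranges over the N chunks:

    m_c = \max_{j \in c} x_j, \qquad
    s_c = \sum_{j \in c} e^{x_j - m_c}, \qquad
    o_c = \sum_{j \in c} e^{x_j - m_c}\, v_j

    m = \max_c m_c, \qquad
    O = \frac{\sum_{c=1}^{N} e^{m_c - m}\, o_c}{\sum_{c=1}^{N} e^{m_c - m}\, s_c}
      = \mathrm{softmax}(x)\, V

Because the rescale is exact, the number of splits is purely a performance knob: larger N exposes more parallelism across the KV sequence when the query length is short (the decode case), without changing results beyond floating-point rounding.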

src/fuse_attention.cpp

Lines changed: 417 additions & 0 deletions
Large diffs are not rendered by default.

src/include/migraphx/fuse_attention.hpp

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,8 @@
 
 #include <migraphx/config.hpp>
 #include <string>
+#include <optional>
+#include <cstddef>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -35,6 +37,7 @@ struct module_pass_manager;
 
 struct MIGRAPHX_EXPORT fuse_attention
 {
+    std::optional<std::size_t> flash_decoding_num_splits = std::nullopt;
     bool attn_enabled = false;
 
     std::string name() const { return "fuse_attention"; }
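The new optional defaults to disabled; the wiring that populates it from ``MIGRAPHX_FLASH_DECODING_NUM_SPLITS`` lives in the unrendered src/fuse_attention.cpp diff above. A minimal sketch of how such a field could be filled, assuming only the 0-means-off convention documented earlier (the helper name parse_flash_decoding_num_splits is hypothetical, not MIGraphX API):

#include <cstddef>
#include <cstdlib>
#include <optional>
#include <string>

// Hypothetical sketch: read MIGRAPHX_FLASH_DECODING_NUM_SPLITS with the
// documented convention that 0 or unset means disabled and N > 1 means
// "split the KV sequence N ways".
std::optional<std::size_t> parse_flash_decoding_num_splits()
{
    const char* env = std::getenv("MIGRAPHX_FLASH_DECODING_NUM_SPLITS");
    if(env == nullptr)
        return std::nullopt;
    std::size_t n = std::stoul(env); // throws std::invalid_argument on non-numeric input
    if(n > 1)
        return n;
    return std::nullopt; // 0 (or 1) leaves flash decoding off
}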

test/fuse_attention.cpp

Lines changed: 290 additions & 8 deletions
Large diffs are not rendered by default.
New file — Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t DType>
struct test_attention_flash_decoding_3d : verify_program<test_attention_flash_decoding_3d<DType>>
{
    migraphx::program create_program() const
    {
        // 3D shape: [batch, sequence_length, head_dim]
        migraphx::shape s_3d{DType, {1, 256, 256}};

        migraphx::program p1;
        auto* mm = p1.get_main_module();
        auto a   = mm->add_parameter("q", s_3d);
        auto b   = mm->add_parameter("k", s_3d);
        auto b1  = mm->add_parameter("v", s_3d);
        b  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b);
        b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1);
        // Attention scores: Q @ K^T
        auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
        // Numerically stable softmax over the KV sequence axis
        auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1);
        rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}),
                                   rmax);
        auto sub  = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax);
        auto exp  = mm->add_instruction(migraphx::make_op("exp"), sub);
        auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp);
        rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}),
                                   rsum);
        auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum);
        // Weighted values: softmax(QK^T) @ V^T
        auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1);
        mm->add_return({gemm2});
        return p1;
    }
};

// These tests are not run by default currently; the env vars below need to be set:
// MIGRAPHX_FLASH_DECODING_NUM_SPLITS=2   # or another split factor
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention
template struct test_attention_flash_decoding_3d<migraphx::shape::half_type>;
template struct test_attention_flash_decoding_3d<migraphx::shape::bf16_type>;
template struct test_attention_flash_decoding_3d<migraphx::shape::float_type>;
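The pattern above is the plain softmax attention that fuse_attention is expected to match and split. As a self-contained sketch (plain C++, independent of MIGraphX, not taken from the diff) of the reduction flash decoding relies on, the following checks that a 2-way split along the KV dimension reproduces the single-pass softmax-weighted sum for one query row:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    // One query row: attention scores x[j] = q . k_j and scalar value rows v[j].
    std::vector<double> x = {0.3, -1.2, 2.0, 0.7};
    std::vector<double> v = {1.0, 2.0, 3.0, 4.0};

    // Single-pass reference: softmax(x) . v with the usual max subtraction.
    double m = *std::max_element(x.begin(), x.end());
    double s = 0.0, o = 0.0;
    for(std::size_t j = 0; j < x.size(); ++j)
    {
        double e = std::exp(x[j] - m);
        s += e;
        o += e * v[j];
    }
    double ref = o / s;

    // 2-way split along the KV dimension: per-chunk (max, sum, partial output),
    // then a log-sum-exp rescale to combine the chunks exactly.
    std::size_t half = x.size() / 2;
    double mc[2], sc[2] = {0, 0}, oc[2] = {0, 0};
    for(int c = 0; c < 2; ++c)
    {
        std::size_t lo = c * half, hi = lo + half;
        mc[c] = *std::max_element(x.begin() + lo, x.begin() + hi);
        for(std::size_t j = lo; j < hi; ++j)
        {
            double e = std::exp(x[j] - mc[c]);
            sc[c] += e;
            oc[c] += e * v[j];
        }
    }
    double mg  = std::max(mc[0], mc[1]);
    double num = std::exp(mc[0] - mg) * oc[0] + std::exp(mc[1] - mg) * oc[1];
    double den = std::exp(mc[0] - mg) * sc[0] + std::exp(mc[1] - mg) * sc[1];
    double split = num / den;

    std::printf("single-pass: %.12f  split: %.12f\n", ref, split);
    assert(std::fabs(ref - split) < 1e-12);
}

The same identity generalizes to any number of chunks, which is why MIGRAPHX_FLASH_DECODING_NUM_SPLITS can be varied freely as a tuning knob.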
New file — Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
/*
 * The MIT License (MIT)
 * (full license header identical to the file above)
 */

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t DType>
struct test_attention_flash_decoding_3d_input_fusion
    : verify_program<test_attention_flash_decoding_3d_input_fusion<DType>>
{
    migraphx::program create_program() const
    {
        // 3D shape: [batch, sequence_length, head_dim]
        migraphx::shape s_3d{DType, {1, 256, 256}};

        migraphx::program p1;
        auto* mm = p1.get_main_module();

        // Input parameters
        auto q_input = mm->add_parameter("q", s_3d);
        auto k_input = mm->add_parameter("k", s_3d);
        auto v_input = mm->add_parameter("v", s_3d);

        // Bias parameters for input fusion
        auto q_bias = mm->add_parameter("q_bias", s_3d);
        auto k_bias = mm->add_parameter("k_bias", s_3d);
        auto v_bias = mm->add_parameter("v_bias", s_3d);

        // Scale parameter (typically 1/sqrt(head_dim))
        migraphx::shape scale_shape{DType, {1}};
        auto scale = mm->add_parameter("scale", scale_shape);

        // Input fusion operations: add bias to Q, K, V
        auto q_with_bias = mm->add_instruction(migraphx::make_op("add"), q_input, q_bias);
        auto k_with_bias = mm->add_instruction(migraphx::make_op("add"), k_input, k_bias);
        auto v_with_bias = mm->add_instruction(migraphx::make_op("add"), v_input, v_bias);

        // Scale Q (common in attention mechanisms)
        scale = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scale);
        auto q_scaled = mm->add_instruction(migraphx::make_op("mul"), q_with_bias, scale);

        // Apply activation (optional input fusion)
        auto q_activated = mm->add_instruction(migraphx::make_op("tanh"), q_scaled);
        auto k_activated = mm->add_instruction(migraphx::make_op("tanh"), k_with_bias);
        auto v_activated = mm->add_instruction(migraphx::make_op("tanh"), v_with_bias);

        // Now perform the attention mechanism with fused inputs.
        // Transpose K and V for matrix multiplication
        auto k_transposed = mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), k_activated);
        auto v_transposed = mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), v_activated);

        // Compute attention scores: Q @ K^T
        auto scores = mm->add_instruction(migraphx::make_op("dot"), q_activated, k_transposed);

        // Apply softmax
        auto scores_max =
            mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), scores);
        scores_max = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_max);
        auto scores_sub = mm->add_instruction(migraphx::make_op("sub"), scores, scores_max);
        auto scores_exp = mm->add_instruction(migraphx::make_op("exp"), scores_sub);
        auto scores_sum =
            mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), scores_exp);
        scores_sum = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_sum);
        auto attention_weights =
            mm->add_instruction(migraphx::make_op("div"), scores_exp, scores_sum);

        // Apply attention weights to values: attention_weights @ V^T
        auto output =
            mm->add_instruction(migraphx::make_op("dot"), attention_weights, v_transposed);

        mm->add_return({output});
        return p1;
    }
};

// These tests are not run by default currently; the env vars below need to be set:
// MIGRAPHX_FLASH_DECODING_NUM_SPLITS=2   # or another split factor
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention
template struct test_attention_flash_decoding_3d_input_fusion<migraphx::shape::half_type>;
template struct test_attention_flash_decoding_3d_input_fusion<migraphx::shape::bf16_type>;
template struct test_attention_flash_decoding_3d_input_fusion<migraphx::shape::float_type>;
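For reference, the ``scale`` parameter this test feeds in plays the role of the usual 1/sqrt(d_k) factor in scaled dot-product attention (the test multiplies by V^T rather than V, but the scale's role is the same):

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,
    \qquad d_k = \text{head\_dim} = 256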
New file — Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
/*
 * The MIT License (MIT)
 * (full license header identical to the first new file above)
 */

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

template <migraphx::shape::type_t DType>
struct test_attention_flash_decoding_3d_output_fusion
    : verify_program<test_attention_flash_decoding_3d_output_fusion<DType>>
{
    migraphx::program create_program() const
    {
        // 3D shape: [batch, sequence_length, head_dim]
        migraphx::shape s_3d{DType, {1, 256, 256}};

        migraphx::program p1;
        auto* mm = p1.get_main_module();

        // Input parameters for attention
        auto q = mm->add_parameter("q", s_3d);
        auto k = mm->add_parameter("k", s_3d);
        auto v = mm->add_parameter("v", s_3d);

        // Parameters for output fusion:
        // output projection weight matrix
        migraphx::shape proj_weight_shape{DType, {256, 256}};
        auto output_proj_weight = mm->add_parameter("output_proj_weight", proj_weight_shape);

        // Output bias
        migraphx::shape bias_shape{DType, {256}};
        auto output_bias = mm->add_parameter("output_bias", bias_shape);

        // Residual input for skip connection
        auto residual = mm->add_parameter("residual", s_3d);

        // Layer norm parameters (gamma and beta)
        auto ln_gamma = mm->add_parameter("ln_gamma", bias_shape);
        auto ln_beta  = mm->add_parameter("ln_beta", bias_shape);

        // Gate for gated output
        auto output_gate = mm->add_parameter("output_gate", s_3d);

        // Standard attention mechanism (no input fusion).
        // Transpose K and V for matrix multiplication
        auto k_transposed =
            mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), k);
        auto v_transposed =
            mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), v);

        // Compute attention scores: Q @ K^T
        auto scores = mm->add_instruction(migraphx::make_op("dot"), q, k_transposed);

        // Apply softmax
        auto scores_max =
            mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), scores);
        scores_max = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_max);
        auto scores_sub = mm->add_instruction(migraphx::make_op("sub"), scores, scores_max);
        auto scores_exp = mm->add_instruction(migraphx::make_op("exp"), scores_sub);
        auto scores_sum =
            mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), scores_exp);
        scores_sum = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), scores_sum);
        auto attention_weights =
            mm->add_instruction(migraphx::make_op("div"), scores_exp, scores_sum);

        // Apply attention weights to values: attention_weights @ V^T
        auto attention_output =
            mm->add_instruction(migraphx::make_op("dot"), attention_weights, v_transposed);

        // Output fusion operations start here.

        // 1. Output projection (linear transformation):
        //    reshape for matrix multiplication with the projection weight
        auto attn_reshaped = mm->add_instruction(
            migraphx::make_op("reshape", {{"dims", {256, 256}}}), attention_output);
        auto projected =
            mm->add_instruction(migraphx::make_op("dot"), attn_reshaped, output_proj_weight);
        projected =
            mm->add_instruction(migraphx::make_op("reshape", {{"dims", s_3d.lens()}}), projected);

        // 2. Add output bias
        output_bias = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), output_bias);
        auto with_bias = mm->add_instruction(migraphx::make_op("add"), projected, output_bias);

        // 3. Apply dropout-like operation (using a gate for deterministic testing)
        auto gate_sigmoid = mm->add_instruction(migraphx::make_op("sigmoid"), output_gate);
        auto gated = mm->add_instruction(migraphx::make_op("mul"), with_bias, gate_sigmoid);

        // 4. Add residual connection
        auto with_residual = mm->add_instruction(migraphx::make_op("add"), gated, residual);

        // 5. Layer normalization: compute mean
        auto mean =
            mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), with_residual);
        mean = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}),
                                   mean);
        auto centered = mm->add_instruction(migraphx::make_op("sub"), with_residual, mean);

        // Compute variance
        auto squared = mm->add_instruction(migraphx::make_op("mul"), centered, centered);
        auto variance =
            mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {2}}}), squared);
        variance = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), variance);

        // Add epsilon for numerical stability
        migraphx::shape epsilon_shape{DType, {1}};
        auto epsilon = mm->add_literal(migraphx::literal{epsilon_shape, {1e-5f}});
        epsilon = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), epsilon);
        auto var_plus_eps = mm->add_instruction(migraphx::make_op("add"), variance, epsilon);

        // Compute standard deviation
        auto std_dev = mm->add_instruction(migraphx::make_op("sqrt"), var_plus_eps);

        // Normalize
        auto normalized = mm->add_instruction(migraphx::make_op("div"), centered, std_dev);

        // Scale and shift
        ln_gamma = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), ln_gamma);
        ln_beta = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), ln_beta);
        auto scaled    = mm->add_instruction(migraphx::make_op("mul"), normalized, ln_gamma);
        auto ln_output = mm->add_instruction(migraphx::make_op("add"), scaled, ln_beta);

        // 6. Final activation (ReLU)
        auto final_output = mm->add_instruction(migraphx::make_op("relu"), ln_output);

        mm->add_return({final_output});
        return p1;
    }
};

// These tests are not run by default currently; the env vars below need to be set:
// MIGRAPHX_FLASH_DECODING_NUM_SPLITS=2   # or another split factor
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS=attention
template struct test_attention_flash_decoding_3d_output_fusion<migraphx::shape::half_type>;
template struct test_attention_flash_decoding_3d_output_fusion<migraphx::shape::bf16_type>;
template struct test_attention_flash_decoding_3d_output_fusion<migraphx::shape::float_type>;
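Step 5 above spells out layer normalization elementwise; in closed form, what the fused epilogue computes per row x (along axis 2, d = head_dim = 256) is:

    \mu = \frac{1}{d}\sum_{i=1}^{d} x_i, \qquad
    \sigma^2 = \frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^2, \qquad
    y = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta, \qquad
    \epsilon = 10^{-5}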
