From 4afcabd1657dd52231eb99eac34b1714e14cc86a Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 25 Sep 2025 11:15:32 -0500 Subject: [PATCH 01/43] Fix reshapes propagation for simplify_qdq and constant propagation --- src/propagate_constant.cpp | 2 +- src/simplify_qdq.cpp | 52 ++++++++++++++++++++++++-------------- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/propagate_constant.cpp b/src/propagate_constant.cpp index 50e065520a9..8b44064fa2f 100644 --- a/src/propagate_constant.cpp +++ b/src/propagate_constant.cpp @@ -40,7 +40,7 @@ static bool skip_propagate(instruction_ref ins) { if(contains({"contiguous", "dequantizelinear", "reshape"}, ins->name())) return skip_propagate(ins->inputs().front()); - if(ins->name() == "unpack_int4") + if(contains({"unpack_int4", "unpack_fp4"}, ins->name())) return true; auto&& s = ins->get_shape(); if(s.broadcasted() and s.element_space() < s.elements()) diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index f4569fe7a36..1c3580e7ebf 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -49,32 +49,36 @@ std::unordered_set get_quantizable_op_names() return s; } -// Helper function to insert quantized versions of any broadcasts and transpose ops that -// occur between dequantizelinear and the quantized op -auto propagate_quantized_ins(module& m, - const instruction_ref dqins, - const instruction_ref qop_arg, - bool is_fp16_model = false) +std::vector get_inbetween_ins(const instruction_ref dqins, + const instruction_ref qop_arg) { auto prev_ins = qop_arg; std::vector ins_between; - // matcher skips continguous, multi/broadcasts and transposes, collect all those - // instructions while(prev_ins != dqins) { ins_between.push_back(prev_ins); prev_ins = prev_ins->inputs().front(); } - auto qinp = dqins->inputs().front(); + return ins_between; +} + +// Helper function to insert quantized versions of any broadcasts and transpose ops that +// occur between dequantizelinear and the quantized op +auto 
propagate_quantized_ins(module& m, + const instruction_ref dqins, + instruction_ref input_ins, + std::vector ins_between, + bool is_fp16_model = false) +{ for(auto ins : reverse_iterator_for(ins_between)) { if((*ins)->name() == "convert" and is_fp16_model) { continue; } - qinp = m.insert_instruction(dqins, (*ins)->get_operator(), {qinp}); + input_ins = m.insert_instruction(dqins, (*ins)->get_operator(), {input_ins}); } - return qinp; + return input_ins; } struct match_find_quantizable_ops @@ -140,8 +144,13 @@ struct match_find_quantizable_ops assert(dq1->get_shape().type() == migraphx::shape::float_type); is_fp16_model = true; } - qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model); - qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model); + + auto qop_between_arg0 = get_inbetween_ins(dq1, qop_args[0]); + auto qop_between_arg1 = get_inbetween_ins(dq2, qop_args[1]); + qop_args.at(0) = + propagate_quantized_ins(m, dq1, qop_args[0], qop_between_arg0, is_fp16_model); + qop_args.at(1) = + propagate_quantized_ins(m, dq2, qop_args[1], qop_between_arg1, is_fp16_model); auto arg1_lens = qop_args[0]->get_shape().lens(); auto arg2_lens = qop_args[1]->get_shape().lens(); @@ -288,7 +297,6 @@ inline auto dynamic_block_dq(const std::string& scale) return match::name("dequantizelinear")( match::nargs(2), match::arg(1)(match::skip_broadcasts(match::none_of( - match::is_constant(), match::scalar_shape, match::ndim(1) ).bind(scale)))); @@ -307,7 +315,7 @@ struct match_find_mx_quantizable_ops { auto dq1 = match::arg(0)(skip_post_dq_ops(dynamic_block_dq("scale1").bind("dq1"))); auto dq2 = match::arg(1)(skip_post_dq_ops(dynamic_block_dq("scale2").bind("dq2"))); - return match::name("dot")(dq1, dq2); + return match::name(get_quantizable_op_names())(dq1, dq2); } void apply(module& m, const match::matcher_result& r) const @@ -328,10 +336,16 @@ struct match_find_mx_quantizable_ops assert(dq1->get_shape().type() == migraphx::shape::float_type); 
is_fp16_model = true; } - qop_args.at(0) = propagate_quantized_ins(m, dq1, qop_args[0], is_fp16_model); - qop_args.at(1) = propagate_quantized_ins(m, dq2, qop_args[1], is_fp16_model); - qop_args.push_back(scale1); - qop_args.push_back(scale2); + auto qop_between_arg0 = get_inbetween_ins(dq1, qop_args[0]); + qop_args.at(0) = + propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model); + auto qop_between_arg1 = get_inbetween_ins(dq2, qop_args[1]); + qop_args.at(1) = + propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model); + qop_args.push_back( + propagate_quantized_ins(m, dq1, scale1, qop_between_arg0, is_fp16_model)); + qop_args.push_back( + propagate_quantized_ins(m, dq2, scale2, qop_between_arg1, is_fp16_model)); if(qop->name() == "convolution") { From 9cffdf056f22763c3f917346b06d84db97e4a29c Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 25 Sep 2025 13:31:09 -0500 Subject: [PATCH 02/43] Fix more bugs and make tests --- src/include/migraphx/raw_data.hpp | 20 +++---- src/simplify_qdq.cpp | 7 +-- test/propagate_constant_test.cpp | 27 +++++++++ test/simplify_qdq_test.cpp | 98 +++++++++++++++++++++++++++++++ 4 files changed, 138 insertions(+), 14 deletions(-) diff --git a/src/include/migraphx/raw_data.hpp b/src/include/migraphx/raw_data.hpp index 63f512a4948..cecbc9cb7c0 100644 --- a/src/include/migraphx/raw_data.hpp +++ b/src/include/migraphx/raw_data.hpp @@ -53,15 +53,15 @@ struct raw_data : raw_data_base friend Stream& operator<<(Stream& os, const Derived& d) { if(not d.empty()) - d.visit([&](auto x) { os << x; }, - [&](auto&& xs) { - for(auto&& x : xs) - { - os << "{ "; - os << x; - os << " }, "; - } - }); + d.fallback_visit([&](auto x) { os << x; }, + [&](auto&& xs) { + for(auto&& x : xs) + { + os << "{ "; + os << x; + os << " }, "; + } + }); return os; } @@ -125,7 +125,7 @@ struct raw_data : raw_data_base { auto&& buffer = static_cast(*this).data(); shape view_shape = {shape::uint8_type, 
{s.bytes()}}; - v(make_view(view_shape, reinterpret_cast(buffer))); + v(make_view(view_shape, const_cast(reinterpret_cast(buffer)))); } } diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 1c3580e7ebf..944c5a2c724 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -289,9 +289,8 @@ struct match_find_quantizable_ops } }; -// Note: scales are not constant b/c of dynamic quantization. // Checks for block quantized scales by checking scales are not scalar or 1D. -inline auto dynamic_block_dq(const std::string& scale) +inline auto block_dq(const std::string& scale) { // clang-format off return match::name("dequantizelinear")( @@ -313,8 +312,8 @@ struct match_find_mx_quantizable_ops { auto matcher() const { - auto dq1 = match::arg(0)(skip_post_dq_ops(dynamic_block_dq("scale1").bind("dq1"))); - auto dq2 = match::arg(1)(skip_post_dq_ops(dynamic_block_dq("scale2").bind("dq2"))); + auto dq1 = match::arg(0)(skip_post_dq_ops(block_dq("scale1").bind("dq1"))); + auto dq2 = match::arg(1)(skip_post_dq_ops(block_dq("scale2").bind("dq2"))); return match::name(get_quantizable_op_names())(dq1, dq2); } diff --git a/test/propagate_constant_test.cpp b/test/propagate_constant_test.cpp index da7511e3482..72011fd8d83 100644 --- a/test/propagate_constant_test.cpp +++ b/test/propagate_constant_test.cpp @@ -535,4 +535,31 @@ TEST_CASE(block_dequantize_int4) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(pack_unpack_fp4) +{ + migraphx::shape s1{migraphx::shape::float_type, {4}}; + migraphx::shape s2{migraphx::shape::fp4x2_type, {2}}; + migraphx::module m1; + { + const std::vector vec = {1.f, 0.f, 2.f, 0.f}; + auto l = m1.add_literal(migraphx::literal(s1, vec)); + auto pack = m1.add_instruction(migraphx::make_op("pack_fp4"), l); + auto unpack = m1.add_instruction(migraphx::make_op("unpack_fp4"), pack); + m1.add_return({unpack}); + } + + run_pass(m1); + + migraphx::module m2; + { + using migraphx::shape; + const std::vector vec = {0x2, 0x4}; + auto l = 
m2.add_literal(migraphx::literal(s2, vec.data())); + auto unpack = m2.add_instruction(migraphx::make_op("unpack_fp4"), l); + m2.add_return({unpack}); + } + + EXPECT(m1 == m2); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index bfc3d61386b..9d82a96eab2 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -1815,6 +1815,104 @@ TEST_CASE(fp4x2_quant_dot_even) EXPECT(m1 == m2); } +TEST_CASE(fp4x2_quant_dot_transB) +{ + migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; + migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 8, 12}}; + migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}}; + migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 8, 24}}; + + migraphx::module m1; + { + auto packed_a = m1.add_parameter("input", shape_packed_a); + auto packed_b = m1.add_parameter("weights", shape_packed_b); + auto scale_a = m1.add_parameter("scale_a", shape_scales_a); + auto scale_b = m1.add_parameter("scale_b", shape_scales_b); + + auto unpack_a = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); + auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); + auto trans_b = m1.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), dq_b); + auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, trans_b); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto packed_a = m2.add_parameter("input", shape_packed_a); + auto packed_b = m2.add_parameter("weights", shape_packed_b); + auto scale_a = m2.add_parameter("scale_a", shape_scales_a); + auto scale_b = m2.add_parameter("scale_b", shape_scales_b); + + 
auto unpack_a = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto trans_b = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unpack_b); + auto trans_scale_b = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), scale_b); + auto quant_dot = m2.add_instruction( + migraphx::make_op("quant_dot"), unpack_a, trans_b, scale_a, trans_scale_b); + m2.add_return({quant_dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(fp4x2_quant_dot_const_B) +{ + migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; + migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 24, 4}}; + migraphx::shape shape_packed_b_gen{migraphx::shape::uint8_type, {1, 3, 24, 4}}; + migraphx::shape shape_scales_a{migraphx::shape::float_type, {1, 3, 6, 24}}; + migraphx::shape shape_scales_b{migraphx::shape::float_type, {1, 3, 24, 8}}; + unsigned long seed = 826; + migraphx::literal b_lit = generate_literal(shape_packed_b_gen, seed); + migraphx::literal scale_b_lit = generate_literal(shape_scales_b, seed); + migraphx::module m1; + { + auto packed_a = m1.add_parameter("input", shape_packed_a); + // avoiding visit fp4x2_type + auto packed_b = m1.add_literal(shape_packed_b, b_lit.data()); + auto scale_a = m1.add_parameter("scale_a", shape_scales_a); + auto scale_b = m1.add_literal(scale_b_lit); + + auto unpack_a = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); + auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); + auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, dq_b); + m1.add_return({dot}); + 
} + + migraphx::module m2; + { + auto packed_a = m2.add_parameter("input", shape_packed_a); + auto packed_b = m2.add_literal(shape_packed_b, b_lit.data()); + auto scale_a = m2.add_parameter("scale_a", shape_scales_a); + auto scale_b = m2.add_literal(scale_b_lit); + + auto unpack_a = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); + auto unpack_b = + m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto quant_dot = m2.add_instruction( + migraphx::make_op("quant_dot"), unpack_a, unpack_b, scale_a, scale_b); + m2.add_return({quant_dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + // Test that unused qdq with pack_fp4, unpack_fp4 are removed TEST_CASE(fp4x2_even_remove_qdq) { From 4d57e7f0a85368e4c9cb950305059b109223d501 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 25 Sep 2025 16:59:40 -0500 Subject: [PATCH 03/43] Fix reinterpret_cast of the possible r-value reference pointer to const char --- src/include/migraphx/raw_data.hpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/include/migraphx/raw_data.hpp b/src/include/migraphx/raw_data.hpp index cecbc9cb7c0..e58a1977c5c 100644 --- a/src/include/migraphx/raw_data.hpp +++ b/src/include/migraphx/raw_data.hpp @@ -125,7 +125,11 @@ struct raw_data : raw_data_base { auto&& buffer = static_cast(*this).data(); shape view_shape = {shape::uint8_type, {s.bytes()}}; - v(make_view(view_shape, const_cast(reinterpret_cast(buffer)))); + using byte_type = std::conditional_t< + std::is_const_v>>, + const byte*, + byte*>; + v(make_view(view_shape, reinterpret_cast(buffer))); } } From 3fa8a1ba3a11d7f227aa226152c67e569bcf574f Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 26 Sep 2025 11:16:45 -0500 Subject: [PATCH 04/43] Fix introduced bug --- src/simplify_qdq.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 944c5a2c724..4cb829bfdaa 100644 --- 
a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -49,8 +49,8 @@ std::unordered_set get_quantizable_op_names() return s; } -std::vector get_inbetween_ins(const instruction_ref dqins, - const instruction_ref qop_arg) +std::vector get_between_ins(const instruction_ref dqins, + const instruction_ref qop_arg) { auto prev_ins = qop_arg; std::vector ins_between; @@ -145,12 +145,12 @@ struct match_find_quantizable_ops is_fp16_model = true; } - auto qop_between_arg0 = get_inbetween_ins(dq1, qop_args[0]); - auto qop_between_arg1 = get_inbetween_ins(dq2, qop_args[1]); + auto qop_between_arg0 = get_between_ins(dq1, qop_args[0]); + auto qop_between_arg1 = get_between_ins(dq2, qop_args[1]); qop_args.at(0) = - propagate_quantized_ins(m, dq1, qop_args[0], qop_between_arg0, is_fp16_model); + propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model); qop_args.at(1) = - propagate_quantized_ins(m, dq2, qop_args[1], qop_between_arg1, is_fp16_model); + propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model); auto arg1_lens = qop_args[0]->get_shape().lens(); auto arg2_lens = qop_args[1]->get_shape().lens(); @@ -335,10 +335,10 @@ struct match_find_mx_quantizable_ops assert(dq1->get_shape().type() == migraphx::shape::float_type); is_fp16_model = true; } - auto qop_between_arg0 = get_inbetween_ins(dq1, qop_args[0]); + auto qop_between_arg0 = get_between_ins(dq1, qop_args[0]); qop_args.at(0) = propagate_quantized_ins(m, dq1, dq1->inputs().front(), qop_between_arg0, is_fp16_model); - auto qop_between_arg1 = get_inbetween_ins(dq2, qop_args[1]); + auto qop_between_arg1 = get_between_ins(dq2, qop_args[1]); qop_args.at(1) = propagate_quantized_ins(m, dq2, dq2->inputs().front(), qop_between_arg1, is_fp16_model); qop_args.push_back( From 7b11af379820b9f9c2824d0d53410453c956a9db Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 26 Sep 2025 11:25:53 -0500 Subject: [PATCH 05/43] Tidy fix --- test/simplify_qdq_test.cpp | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index 9d82a96eab2..8d548224762 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -1815,7 +1815,7 @@ TEST_CASE(fp4x2_quant_dot_even) EXPECT(m1 == m2); } -TEST_CASE(fp4x2_quant_dot_transB) +TEST_CASE(fp4x2_quant_dot_trans_b) { migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 8, 12}}; @@ -1865,7 +1865,7 @@ TEST_CASE(fp4x2_quant_dot_transB) EXPECT(m1 == m2); } -TEST_CASE(fp4x2_quant_dot_const_B) +TEST_CASE(fp4x2_quant_dot_const_b) { migraphx::shape shape_packed_a{migraphx::shape::fp4x2_type, {1, 3, 6, 12}}; migraphx::shape shape_packed_b{migraphx::shape::fp4x2_type, {1, 3, 24, 4}}; From 335439b5a7f4f0b719fa9f092ea7df0ad8f1b51e Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 26 Sep 2025 16:50:13 -0500 Subject: [PATCH 06/43] initial --- src/targets/gpu/mlir.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 19e0ae70b44..bb9014cc583 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -332,6 +332,8 @@ struct mlir_program result = mlirFloat8E5M2TypeGet(ctx.get()); else if(as.type_enum() == shape::double_type) result = mlirF64TypeGet(ctx.get()); + else if(as.type_enum() == shape::fp4x2_type) + result = mlirFloat8E4M3FNTypeGet(ctx.get()); else if(as.is_integral()) { if(as.is_unsigned()) @@ -647,6 +649,8 @@ struct mlir_program return "migraphx.literal"; if(ins->name() == "unpack_int4") return "migraphx.unpack"; + if(ins->name() == "unpack_fp4") + return "migraphx.unpack"; if(ins->name() == "convolution_backwards") return "migraphx.backwards_data_convolution"; if(is_reshape(ins->name())) From 5e04de70a2a2f1690514cbc03b497bf2375b08fb Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 29 Sep 2025 11:46:23 -0500 Subject: [PATCH 07/43] reviews code style --- 
src/include/migraphx/raw_data.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/include/migraphx/raw_data.hpp b/src/include/migraphx/raw_data.hpp index e58a1977c5c..ebd72acaab6 100644 --- a/src/include/migraphx/raw_data.hpp +++ b/src/include/migraphx/raw_data.hpp @@ -123,12 +123,12 @@ struct raw_data : raw_data_base } else { - auto&& buffer = static_cast(*this).data(); + auto* buffer = static_cast(*this).data(); shape view_shape = {shape::uint8_type, {s.bytes()}}; - using byte_type = std::conditional_t< - std::is_const_v>>, - const byte*, - byte*>; + using byte_type = + std::conditional_t>{}, + const byte*, + byte*>; v(make_view(view_shape, reinterpret_cast(buffer))); } } From a29f4d93d527385f980d3bfeb96e2a0542a3103a Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 29 Sep 2025 13:39:28 -0500 Subject: [PATCH 08/43] Update rocMLIR and rocm --- Dockerfile | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 7aecb0c685e..67ffb7d3956 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y software-properties-common gnupg2 --no- curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - # Add rocm repository -RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/6.4.2/ jammy main > /etc/apt/sources.list.d/rocm.list' +RUN sh -c 'echo deb [arch=amd64 trusted=yes] http://repo.radeon.com/rocm/apt/7.0.1/ jammy main > /etc/apt/sources.list.d/rocm.list' # From docs.amd.com for installing rocm. 
Needed to install properly RUN sh -c "echo 'Package: *\nPin: release o=repo.radeon.com\nPin-priority: 600' > /etc/apt/preferences.d/rocm-pin-600" diff --git a/requirements.txt b/requirements.txt index 708b53287f1..8a021faaa90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@265d98be2fd97b90b9b1c2da60b0f79a745059be -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off +ROCm/rocMLIR@d2d5033dd9a99da98a53e69b6067e15ab3048d1d -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off From 473c85ba0379926307f66c63b980781f00c08cd2 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 29 Sep 2025 14:00:02 -0500 Subject: [PATCH 09/43] Avoid visit of fp4x2_type --- src/targets/gpu/mlir.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index bb9014cc583..7b833acf5a3 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -315,6 +315,11 @@ struct mlir_program MlirType make_type(shape::type_t t) const { MlirType result; + // non-computable type is not visit-able + if(t == shape::fp4x2_type) + { + result = mlirFloat8E4M3FNTypeGet(ctx.get()); + } shape::visit(t, [&](auto as) { if(as.type_enum() == shape::float_type) result = mlirF32TypeGet(ctx.get()); @@ -332,8 +337,6 @@ struct mlir_program result = mlirFloat8E5M2TypeGet(ctx.get()); else if(as.type_enum() == shape::double_type) result = mlirF64TypeGet(ctx.get()); - else if(as.type_enum() == shape::fp4x2_type) - result = mlirFloat8E4M3FNTypeGet(ctx.get()); else if(as.is_integral()) { if(as.is_unsigned()) From 866dbf7da6e14998734bd9de4af198bb6bec29ad Mon 
Sep 17 00:00:00 2001 From: charlie Date: Mon, 29 Sep 2025 14:11:11 -0500 Subject: [PATCH 10/43] typo fix --- src/targets/gpu/mlir.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 7b833acf5a3..b292ad9336b 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -314,12 +314,12 @@ struct mlir_program MlirType make_type(shape::type_t t) const { - MlirType result; // non-computable type is not visit-able if(t == shape::fp4x2_type) { - result = mlirFloat8E4M3FNTypeGet(ctx.get()); + return mlirFloat8E4M3FNTypeGet(ctx.get()); } + MlirType result; shape::visit(t, [&](auto as) { if(as.type_enum() == shape::float_type) result = mlirF32TypeGet(ctx.get()); From c42d896ce21f9f358481e88a24ee0ea4eb9df9bd Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 29 Sep 2025 14:36:43 -0500 Subject: [PATCH 11/43] Remove quant_conv again --- src/simplify_qdq.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 4cb829bfdaa..4169964d9d0 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -314,7 +314,7 @@ struct match_find_mx_quantizable_ops { auto dq1 = match::arg(0)(skip_post_dq_ops(block_dq("scale1").bind("dq1"))); auto dq2 = match::arg(1)(skip_post_dq_ops(block_dq("scale2").bind("dq2"))); - return match::name(get_quantizable_op_names())(dq1, dq2); + return match::name("dot")(dq1, dq2); } void apply(module& m, const match::matcher_result& r) const From 5eacd8760219cd0d6ca7a80ea42aa1963c6fdcc6 Mon Sep 17 00:00:00 2001 From: charlie Date: Mon, 29 Sep 2025 16:52:12 -0500 Subject: [PATCH 12/43] Update requirements rocmlir to latest branch --- requirements.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8a021faaa90..a0d9a25cd81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ 
-##################################################################################### +# +################################################################################### # The MIT License (MIT) # # Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. @@ -29,4 +30,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@d2d5033dd9a99da98a53e69b6067e15ab3048d1d -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off +ROCm/rocMLIR@1e65672899f431b6da1285e5aecfe04a159b1932 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off From 2ecd0e30333249a95c5988ae54ac2712ab7e6e9b Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 30 Sep 2025 16:11:01 -0500 Subject: [PATCH 13/43] Add Umang's changes --- src/targets/gpu/mlir.cpp | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index b292ad9336b..a2a36ded85f 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -508,6 +508,17 @@ struct mlir_program { } + void setOperandSegmentSizes(int numSegments, const int32_t* sizes) + { + MlirAttribute segmentSizesAttr = + mlirDenseI32ArrayGet(prog->ctx.get(), numSegments, sizes); + MlirNamedAttribute namedAttr = mlirNamedAttributeGet( + mlirIdentifierGet(prog->ctx.get(), + mlirStringRefCreateFromCString("operandSegmentSizes")), + segmentSizesAttr); + mlirOperationStateAddAttributes(&op_state, 1, &namedAttr); + } + mlir_operation_state& add_attributes(const std::vector& named_attrs) { auto attributes = prog->name_attributes(named_attrs); @@ -755,9 +766,20 @@ struct mlir_program std::vector inputs; transform( ins->inputs(), std::back_inserter(inputs), [&](auto 
i) { return ins_map.at(i); }); + + if(ins->name() == "quant_dot" && + ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type && + ins->inputs().size() == 4) + { + // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. + // Use the canonical MLIR attribute name 'operandSegmentSizes'. + int32_t seg_sizes[] = {1, 1, 1, 1}; + ops.setOperandSegmentSizes(4, seg_sizes); + } ops.add_operands(inputs); auto outputs = insert(fbody, std::move(ops)); + if(ins->name() != "@return") { assert(outputs.size() == 1); @@ -1208,6 +1230,7 @@ mlir_code_object compile_mlir(const context& migraphx_ctx, const std::lock_guard lock(mutex); std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; } + auto co = mp.compile(solution); co.expected_inputs = in_shapes; From a91f16b4183ec745617f2bbbd333b9b4012b4533 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 10 Oct 2025 16:49:15 -0500 Subject: [PATCH 14/43] In Progress --- src/driver/verify.cpp | 14 +++++- src/include/migraphx/simplify_qdq.hpp | 1 + src/simplify_qdq.cpp | 33 ++++++++----- src/targets/ref/target.cpp | 3 ++ test/verify/test_mxfp4_gemm.cpp | 68 +++++++++++++++++++++++++++ 5 files changed, 104 insertions(+), 15 deletions(-) create mode 100644 test/verify/test_mxfp4_gemm.cpp diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp index c47322f83da..009b454cb45 100644 --- a/src/driver/verify.cpp +++ b/src/driver/verify.cpp @@ -43,7 +43,8 @@ inline namespace MIGRAPHX_INLINE_NS { /** * Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults. - * Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the + * Sets to fp4 tolerances if any fp4x2_type is found. + * Else sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the * model. 
*/ verify::tolerance get_tolerances(const program& p, @@ -58,8 +59,17 @@ verify::tolerance get_tolerances(const program& p, ins.get_shape().type() == shape::bf16_type); }); }); + bool has_fp4 = any_of(p.get_modules(), [](auto&& m) { + return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::fp4x2_type); }); + }); migraphx::verify::tolerance result{}; - if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16) + if(has_fp4) + { + result.rms_tol = 8e-1; + result.atol = 4e-1; + result.rtol = 4e-1; + } + else if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16) { result.rms_tol = 8e-2; result.atol = 4e-2; diff --git a/src/include/migraphx/simplify_qdq.hpp b/src/include/migraphx/simplify_qdq.hpp index 7be4efc7cb2..18892688221 100644 --- a/src/include/migraphx/simplify_qdq.hpp +++ b/src/include/migraphx/simplify_qdq.hpp @@ -38,6 +38,7 @@ struct module; */ struct MIGRAPHX_EXPORT simplify_qdq { + bool remove_qdq_only = false; std::string name() const { return "simplify_qdq"; } void apply(module& m) const; }; diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 4169964d9d0..d93c6b79b4d 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -618,19 +618,26 @@ void add_int4_pack_unpack_pair(module& m) void simplify_qdq::apply(module& m) const { - // first step: add pack/unpack pair between qdq for int4 weights - add_int4_pack_unpack_pair(m); - match::find_matches(m, match_find_quantizable_ops{}); - migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); - match::find_matches(m, match_find_mx_quantizable_ops{}); - migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); - match::find_matches(m, remove_qdq_pairs{}); - migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); - match::find_matches(m, match_qlinear_reused{}); - migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); - match::find_matches(m, match_concat_qlinear{}); - migraphx::run_passes(m, 
{migraphx::dead_code_elimination{}}); - remove_zero_point(m); + if(remove_qdq_only) + { + match::find_matches(m, remove_qdq_pairs{}); + } + else + { + // first step: add pack/unpack pair between qdq for int4 weights + add_int4_pack_unpack_pair(m); + match::find_matches(m, match_find_quantizable_ops{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + match::find_matches(m, match_find_mx_quantizable_ops{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + match::find_matches(m, remove_qdq_pairs{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + match::find_matches(m, match_qlinear_reused{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + match::find_matches(m, match_concat_qlinear{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + remove_zero_point(m); + } } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/ref/target.cpp b/src/targets/ref/target.cpp index 13c15e541e3..5dde5027c47 100644 --- a/src/targets/ref/target.cpp +++ b/src/targets/ref/target.cpp @@ -35,6 +35,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -51,6 +52,8 @@ std::vector target::get_passes(migraphx::context&, const compile_options&) dead_code_elimination{}, rewrite_rnn{}, dead_code_elimination{}, + simplify_qdq{.remove_qdq_only = true}, + dead_code_elimination{}, auto_contiguous{}, dead_code_elimination{}, lowering{}, diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp new file mode 100644 index 00000000000..ca1f1af100d --- /dev/null +++ b/test/verify/test_mxfp4_gemm.cpp @@ -0,0 +1,68 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include "verify_program.hpp" +#include +#include +#include + +instruction_ref add_dyn_scale_calc(instruction_ref input) +{ + // TODO +} + +struct test_mxfp4_gemm : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + migraphx::module_ref mmain = p.get_main_module(); + // TODO these scale literals need to be E8M0 values + auto x_0 = mmain->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 1000}}, 0)); + auto x_1 = mmain->add_literal(migraphx::abs(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {64, 1, 1000}, {1, 1, 64}}, 1))); + auto x_2 = mmain->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::fp4x2_type, {1000, 1024}}, 2)); + auto p_x3 = + mmain->add_parameter("x3", migraphx::shape{migraphx::shape::float_type, {1, 64, 1}}); + auto p_x1 = + mmain->add_parameter("x1", migraphx::shape{migraphx::shape::fp4x2_type, {1, 1024}}); + auto x_5 = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), p_x1); + auto x_6 = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), x_2); + auto x_7 = + mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x_6); + auto x_8 = mmain->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 32}}}), p_x3); + auto x_9 = mmain->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2048}}}), x_8); + auto x_10 = mmain->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {64, 32, 1000}}}), x_1); + auto x_11 = + mmain->add_instruction(migraphx::make_op("reshape", {{"dims", {2048, 1000}}}), x_10); + auto x_12 = mmain->add_instruction(migraphx::make_op("quant_dot"), x_5, x_7, x_9, x_11); + auto x_13 = mmain->add_instruction(migraphx::make_op("add"), x_12, x_0); + mmain->add_return({x_13}); + } + std::string section() const { return "gemm"; } +}; From 0f265328a8eeb9f41b53f29e1feb9f62709d7e75 Mon Sep 17 
00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 10:03:29 -0500 Subject: [PATCH 15/43] tidy up --- requirements.txt | 1 - src/targets/gpu/mlir.cpp | 14 +++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index a48ac078824..da745695eed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -# ################################################################################### # The MIT License (MIT) # diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index a2a36ded85f..89c50f1d116 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -508,15 +508,15 @@ struct mlir_program { } - void setOperandSegmentSizes(int numSegments, const int32_t* sizes) + void set_operand_segement_sizes(int num_segments, const int32_t* sizes) { - MlirAttribute segmentSizesAttr = - mlirDenseI32ArrayGet(prog->ctx.get(), numSegments, sizes); - MlirNamedAttribute namedAttr = mlirNamedAttributeGet( + MlirAttribute segment_sizes_attr = + mlirDenseI32ArrayGet(prog->ctx.get(), num_segments, sizes); + MlirNamedAttribute named_attr = mlirNamedAttributeGet( mlirIdentifierGet(prog->ctx.get(), mlirStringRefCreateFromCString("operandSegmentSizes")), - segmentSizesAttr); - mlirOperationStateAddAttributes(&op_state, 1, &namedAttr); + segment_sizes_attr); + mlirOperationStateAddAttributes(&op_state, 1, &named_attr); } mlir_operation_state& add_attributes(const std::vector& named_attrs) @@ -774,7 +774,7 @@ struct mlir_program // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. // Use the canonical MLIR attribute name 'operandSegmentSizes'. 
int32_t seg_sizes[] = {1, 1, 1, 1}; - ops.setOperandSegmentSizes(4, seg_sizes); + ops.set_operand_segement_sizes(4, seg_sizes); } ops.add_operands(inputs); From d05b8f8339d8aa448bdfb6855546cea7d948d082 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 12:30:21 -0500 Subject: [PATCH 16/43] Enable simplify_qdq for unpack_fp4 only if >=MI350 --- src/include/migraphx/simplify_qdq.hpp | 1 + src/simplify_qdq.cpp | 7 +++++-- src/targets/gpu/device_name.cpp | 5 +++++ src/targets/gpu/target.cpp | 2 +- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/simplify_qdq.hpp b/src/include/migraphx/simplify_qdq.hpp index 18892688221..16d32534792 100644 --- a/src/include/migraphx/simplify_qdq.hpp +++ b/src/include/migraphx/simplify_qdq.hpp @@ -39,6 +39,7 @@ struct module; struct MIGRAPHX_EXPORT simplify_qdq { bool remove_qdq_only = false; + bool use_mx_quant = false; std::string name() const { return "simplify_qdq"; } void apply(module& m) const; }; diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index d93c6b79b4d..973e592a3e6 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -628,8 +628,11 @@ void simplify_qdq::apply(module& m) const add_int4_pack_unpack_pair(m); match::find_matches(m, match_find_quantizable_ops{}); migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); - match::find_matches(m, match_find_mx_quantizable_ops{}); - migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + if(use_mx_quant) + { + match::find_matches(m, match_find_mx_quantizable_ops{}); + migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); + } match::find_matches(m, remove_qdq_pairs{}); migraphx::run_passes(m, {migraphx::dead_code_elimination{}}); match::find_matches(m, match_qlinear_reused{}); diff --git a/src/targets/gpu/device_name.cpp b/src/targets/gpu/device_name.cpp index ae86806c38d..7f98a70a8f5 100644 --- a/src/targets/gpu/device_name.cpp +++ b/src/targets/gpu/device_name.cpp @@ -75,6 +75,11 @@ bool 
gfx_has_bf16_intrinsics() return not(starts_with(device_name, "gfx1030")); } +bool gfx_has_mx_intrinsics() +{ + return starts_with(device_name, "gfx9") and device_name >= "gfx950"; +} + #if MIGRAPHX_USE_HIPBLASLT // Archs that support hipBLASLt but are defaulted to use rocBLAS. bool gfx_default_rocblas() diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 5844c934259..0fb9376700a 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -189,7 +189,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), fp8_ocp_to_fnuz{}), enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), dead_code_elimination{}), - simplify_qdq{}, + simplify_qdq{.use_mx_quant=gfx_has_mx_intrinsics()}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, rewrite_rnn{}, From 1299fd45efba9d9c922a0e54d53a3e173e22ff3b Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 13:30:53 -0500 Subject: [PATCH 17/43] Add mxfp4 quant_dot verify test --- src/targets/gpu/mlir.cpp | 10 ++-- test/verify/test_mxfp4_gemm.cpp | 102 ++++++++++++++++++++++++-------- 2 files changed, 81 insertions(+), 31 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 89c50f1d116..18690c07475 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -508,10 +508,10 @@ struct mlir_program { } - void set_operand_segement_sizes(int num_segments, const int32_t* sizes) + void set_operand_segement_sizes(int num_segments, const vector& sizes) { MlirAttribute segment_sizes_attr = - mlirDenseI32ArrayGet(prog->ctx.get(), num_segments, sizes); + mlirDenseI32ArrayGet(prog->ctx.get(), num_segments, sizes.data()); MlirNamedAttribute named_attr = mlirNamedAttributeGet( mlirIdentifierGet(prog->ctx.get(), 
mlirStringRefCreateFromCString("operandSegmentSizes")), @@ -767,13 +767,13 @@ struct mlir_program transform( ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); }); - if(ins->name() == "quant_dot" && - ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type && + if(ins->name() == "quant_dot" and + ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type and ins->inputs().size() == 4) { // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. // Use the canonical MLIR attribute name 'operandSegmentSizes'. - int32_t seg_sizes[] = {1, 1, 1, 1}; + const vector seg_sizes = {1, 1, 1, 1}; ops.set_operand_segement_sizes(4, seg_sizes); } ops.add_operands(inputs); diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index ca1f1af100d..8a9795adb84 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -27,42 +27,92 @@ #include #include -instruction_ref add_dyn_scale_calc(instruction_ref input) +instruction_ref add_dyn_scale_calc(module_ref m, instruction_ref input, int block_axis, int block_size) { - // TODO + // Code similar to that in parse_mxfixneruon + // make reduction axes for calculating block scales + // tmp_lens != input_lens if runt block is padded + instruction_ref tmp_in = input; + const auto input_lens = input->get_shape().lens(); + auto tmp_lens = input_lens; + auto block_dim = tmp_lens.at(block_axis); + std::size_t block_padding = + std::ceil(double(block_dim) / double(block_size)) * block_size - block_dim; + // handle runt block by padding + if(block_padding != 0) + { + std::vector pads_vec(2 * tmp_lens.size(), 0); + pads_vec.at(block_axis + tmp_lens.size()) = block_padding; + tmp_in = m.add_instruction(make_op("pad", {{"pads", pads_vec}}), tmp_in); + tmp_lens = tmp_in->get_shape().lens(); + } + // reshape block dimension to {num_blocks, block_size} + std::size_t num_blocks = tmp_lens.at(block_axis) / std::size_t(block_size); + 
std::vector reduct_dims = tmp_lens; + reduct_dims.at(block_axis) = block_size; + reduct_dims.insert(reduct_dims.begin() + block_axis, num_blocks); + instruction_ref reshape_ins = m.add_instruction(make_op("reshape", {{"dims", reduct_dims}}), tmp_in); + + // dynamic quantization for MX types: + // V_k = fp32 vector input of block size k + // B_k = pow(2, floor(log2(reduce_max(abs(V_k))))) # largest power of 2 less than V + // X_k = block scale k = B_k / (largest power of 2 in fp4e2m1) = B_k / 4 + auto abs_ins = m.add_instruction(make_op("abs"), reshape_ins); + auto reduce_max_ins = + m.add_instruction(make_op("reduce_max", {{"axes", {block_axis + 1}}}), abs_ins); + auto log2_ins = m.add_instruction(make_op("log2"), reduce_max_ins); + auto floor_ins = m.add_instruction(make_op("floor"), log2_ins); + auto lit_2_ins = m.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {2.f}}); + auto broadcast_lit_2_ins = m.add_instruction( + make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_2_ins); + auto pow_ins = m.add_instruction(make_op("pow"), broadcast_lit_2_ins, floor_ins); + auto lit_4_ins = m.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {4.f}}); + auto broadcast_lit_4_ins = m.add_instruction( + make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), + lit_4_ins); + auto block_scales_ins = m.add_instruction(make_op("div"), pow_ins, broadcast_lit_4_ins); + + // broadcast scales for use in quantizelinear + block_scales_ins = m.add_instruction( + make_op("multibroadcast", {{"out_lens", reduct_dims}}), block_scales_ins); + block_scales_ins = + m.add_instruction(make_op("reshape", {{"dims", tmp_lens}}), block_scales_ins); + + // if padded runt block do slicing + if(tmp_lens != input_lens) + { + std::size_t slice_size = input_lens.at(block_axis); + block_scales_ins = m.add_instruction( + make_op("slice", {{"axes", {block_axis}}, {"starts", {0}}, {"ends", 
{slice_size}}}), + block_scales_ins); + } + return block_scales_ins; } +/** + * Designed to be a quant_dot with MX block scales after the simplify_qdq pass. + */ struct test_mxfp4_gemm : verify_program { migraphx::program create_program() const { migraphx::program p; migraphx::module_ref mmain = p.get_main_module(); - // TODO these scale literals need to be E8M0 values - auto x_0 = mmain->add_literal( - migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 1000}}, 0)); - auto x_1 = mmain->add_literal(migraphx::abs(migraphx::generate_literal( - migraphx::shape{migraphx::shape::float_type, {64, 1, 1000}, {1, 1, 64}}, 1))); - auto x_2 = mmain->add_literal(migraphx::generate_literal( - migraphx::shape{migraphx::shape::fp4x2_type, {1000, 1024}}, 2)); - auto p_x3 = - mmain->add_parameter("x3", migraphx::shape{migraphx::shape::float_type, {1, 64, 1}}); - auto p_x1 = + auto input = mmain->add_parameter("x1", migraphx::shape{migraphx::shape::fp4x2_type, {1, 1024}}); - auto x_5 = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), p_x1); - auto x_6 = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), x_2); - auto x_7 = - mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), x_6); - auto x_8 = mmain->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 32}}}), p_x3); - auto x_9 = mmain->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2048}}}), x_8); - auto x_10 = mmain->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {64, 32, 1000}}}), x_1); - auto x_11 = - mmain->add_instruction(migraphx::make_op("reshape", {{"dims", {2048, 1000}}}), x_10); - auto x_12 = mmain->add_instruction(migraphx::make_op("quant_dot"), x_5, x_7, x_9, x_11); - auto x_13 = mmain->add_instruction(migraphx::make_op("add"), x_12, x_0); - mmain->add_return({x_13}); + auto input_scales = add_dyn_scale_calc(mmain, input, 1, 32); + auto bias = 
mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 1000}}, 0)); + auto weights = mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::fp4x2_type, {1000, 1024}}, 2)); + auto weight_scales = add_dyn_scale_calc(mmain, weights, 1, 32); + input = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), input); + weights = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), weights); + weights = mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), weights); + auto quant_dot = mmain->add_instruction(migraphx::make_op("quant_dot"), input, weights, x_9, weight_scales); + auto bias_add = mmain->add_instruction(migraphx::make_op("add"), quant_dot, bias); + mmain->add_return({bias_add}); } std::string section() const { return "gemm"; } }; From 77bfb9b3071431b1669eb944b0e1e615e5d1cadf Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 13:36:04 -0500 Subject: [PATCH 18/43] Fix typos --- requirements.txt | 2 +- src/targets/gpu/device_name.cpp | 1 + src/targets/gpu/mlir.cpp | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index da745695eed..2fc0d11cbd4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. 
-##################################################################################### +################################################################################### abseil/abseil-cpp@20250512.0 -DABSL_ENABLE_INSTALL=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON google/protobuf@v30.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -Dprotobuf_BUILD_TESTS=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 nlohmann/json@v3.8.0 -DCMAKE_POLICY_VERSION_MINIMUM=3.5 diff --git a/src/targets/gpu/device_name.cpp b/src/targets/gpu/device_name.cpp index 7f98a70a8f5..5242bda8caa 100644 --- a/src/targets/gpu/device_name.cpp +++ b/src/targets/gpu/device_name.cpp @@ -77,6 +77,7 @@ bool gfx_has_bf16_intrinsics() bool gfx_has_mx_intrinsics() { + const auto device_name = trim(split_string(get_device_name(), ':').front()); return starts_with(device_name, "gfx9") and device_name >= "gfx950"; } diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 18690c07475..9e744a3e63b 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -508,7 +508,7 @@ struct mlir_program { } - void set_operand_segement_sizes(int num_segments, const vector& sizes) + void set_operand_segement_sizes(int num_segments, const std::vector& sizes) { MlirAttribute segment_sizes_attr = mlirDenseI32ArrayGet(prog->ctx.get(), num_segments, sizes.data()); @@ -773,7 +773,7 @@ struct mlir_program { // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. // Use the canonical MLIR attribute name 'operandSegmentSizes'. 
- const vector seg_sizes = {1, 1, 1, 1}; + const std::vector seg_sizes = {1, 1, 1, 1}; ops.set_operand_segement_sizes(4, seg_sizes); } ops.add_operands(inputs); From 1eb70a51186aebaa7c9b3e45278ecc9d225a7609 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 13:38:12 -0500 Subject: [PATCH 19/43] Add to header --- src/targets/gpu/include/migraphx/gpu/device_name.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/targets/gpu/include/migraphx/gpu/device_name.hpp b/src/targets/gpu/include/migraphx/gpu/device_name.hpp index 19649077747..281688b7ff2 100644 --- a/src/targets/gpu/include/migraphx/gpu/device_name.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device_name.hpp @@ -43,8 +43,11 @@ MIGRAPHX_GPU_EXPORT bool gfx_has_fp8ocp_intrinsics(); MIGRAPHX_GPU_EXPORT bool gfx_has_bf16_intrinsics(); +MIGRAPHX_GPU_EXPORT bool gfx_has_mx_intrinsics(); + MIGRAPHX_GPU_EXPORT bool gfx_has_fp8fnuz_support(); + #if MIGRAPHX_USE_HIPBLASLT MIGRAPHX_GPU_EXPORT bool gfx_default_rocblas(); #endif From 57bd0eb7ab13697d5f1b1a70424b5ad2a0dcf0ce Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 13:42:40 -0500 Subject: [PATCH 20/43] Use -> and using --- test/verify/test_mxfp4_gemm.cpp | 35 +++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 8a9795adb84..186afb752c4 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -29,6 +29,11 @@ instruction_ref add_dyn_scale_calc(module_ref m, instruction_ref input, int block_axis, int block_size) { + + using migraphx::instruction_ref; + using migraphx::module_ref; + using migraphx::make_op; + // Code similar to that in parse_mxfixneruon // make reduction axes for calculating block scales // tmp_lens != input_lens if runt block is padded @@ -43,7 +48,7 @@ instruction_ref add_dyn_scale_calc(module_ref m, instruction_ref input, int bloc { std::vector pads_vec(2 * 
tmp_lens.size(), 0); pads_vec.at(block_axis + tmp_lens.size()) = block_padding; - tmp_in = m.add_instruction(make_op("pad", {{"pads", pads_vec}}), tmp_in); + tmp_in = m->add_instruction(make_op("pad", {{"pads", pads_vec}}), tmp_in); tmp_lens = tmp_in->get_shape().lens(); } // reshape block dimension to {num_blocks, block_size} @@ -51,41 +56,41 @@ instruction_ref add_dyn_scale_calc(module_ref m, instruction_ref input, int bloc std::vector reduct_dims = tmp_lens; reduct_dims.at(block_axis) = block_size; reduct_dims.insert(reduct_dims.begin() + block_axis, num_blocks); - instruction_ref reshape_ins = m.add_instruction(make_op("reshape", {{"dims", reduct_dims}}), tmp_in); + instruction_ref reshape_ins = m->add_instruction(make_op("reshape", {{"dims", reduct_dims}}), tmp_in); // dynamic quantization for MX types: // V_k = fp32 vector input of block size k // B_k = pow(2, floor(log2(reduce_max(abs(V_k))))) # largest power of 2 less than V // X_k = block scale k = B_k / (largest power of 2 in fp4e2m1) = B_k / 4 - auto abs_ins = m.add_instruction(make_op("abs"), reshape_ins); + auto abs_ins = m->add_instruction(make_op("abs"), reshape_ins); auto reduce_max_ins = - m.add_instruction(make_op("reduce_max", {{"axes", {block_axis + 1}}}), abs_ins); - auto log2_ins = m.add_instruction(make_op("log2"), reduce_max_ins); - auto floor_ins = m.add_instruction(make_op("floor"), log2_ins); - auto lit_2_ins = m.add_literal( + m->add_instruction(make_op("reduce_max", {{"axes", {block_axis + 1}}}), abs_ins); + auto log2_ins = m->add_instruction(make_op("log2"), reduce_max_ins); + auto floor_ins = m->add_instruction(make_op("floor"), log2_ins); + auto lit_2_ins = m->add_literal( migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {2.f}}); - auto broadcast_lit_2_ins = m.add_instruction( + auto broadcast_lit_2_ins = m->add_instruction( make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), lit_2_ins); - auto pow_ins = m.add_instruction(make_op("pow"), 
broadcast_lit_2_ins, floor_ins); - auto lit_4_ins = m.add_literal( + auto pow_ins = m->add_instruction(make_op("pow"), broadcast_lit_2_ins, floor_ins); + auto lit_4_ins = m->add_literal( migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {4.f}}); - auto broadcast_lit_4_ins = m.add_instruction( + auto broadcast_lit_4_ins = m->add_instruction( make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), lit_4_ins); - auto block_scales_ins = m.add_instruction(make_op("div"), pow_ins, broadcast_lit_4_ins); + auto block_scales_ins = m->add_instruction(make_op("div"), pow_ins, broadcast_lit_4_ins); // broadcast scales for use in quantizelinear - block_scales_ins = m.add_instruction( + block_scales_ins = m->add_instruction( make_op("multibroadcast", {{"out_lens", reduct_dims}}), block_scales_ins); block_scales_ins = - m.add_instruction(make_op("reshape", {{"dims", tmp_lens}}), block_scales_ins); + m->add_instruction(make_op("reshape", {{"dims", tmp_lens}}), block_scales_ins); // if padded runt block do slicing if(tmp_lens != input_lens) { std::size_t slice_size = input_lens.at(block_axis); - block_scales_ins = m.add_instruction( + block_scales_ins = m->add_instruction( make_op("slice", {{"axes", {block_axis}}, {"starts", {0}}, {"ends", {slice_size}}}), block_scales_ins); } From 46c938cd582d352435646a7cb44e3a6109e376af Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 13:49:31 -0500 Subject: [PATCH 21/43] More fixes --- test/verify/test_mxfp4_gemm.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 186afb752c4..bf3953b73a4 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -27,9 +27,8 @@ #include #include -instruction_ref add_dyn_scale_calc(module_ref m, instruction_ref input, int block_axis, int block_size) +instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::instruction_ref input, int 
block_axis, int block_size) { - using migraphx::instruction_ref; using migraphx::module_ref; using migraphx::make_op; @@ -115,7 +114,7 @@ struct test_mxfp4_gemm : verify_program input = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), input); weights = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), weights); weights = mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), weights); - auto quant_dot = mmain->add_instruction(migraphx::make_op("quant_dot"), input, weights, x_9, weight_scales); + auto quant_dot = mmain->add_instruction(migraphx::make_op("quant_dot"), input, weights, input_scales, weight_scales); auto bias_add = mmain->add_instruction(migraphx::make_op("add"), quant_dot, bias); mmain->add_return({bias_add}); } From 7b5cd4630bf9ee6ea78599e86f5e08ba59e21cf7 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 13:49:31 -0500 Subject: [PATCH 22/43] More fixes --- test/verify/test_mxfp4_gemm.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 186afb752c4..0215dfd8383 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -26,10 +26,10 @@ #include #include #include +#include -instruction_ref add_dyn_scale_calc(module_ref m, instruction_ref input, int block_axis, int block_size) +instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::instruction_ref input, int block_axis, int block_size) { - using migraphx::instruction_ref; using migraphx::module_ref; using migraphx::make_op; @@ -115,7 +115,7 @@ struct test_mxfp4_gemm : verify_program input = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), input); weights = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), weights); weights = mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), weights); - auto quant_dot = 
mmain->add_instruction(migraphx::make_op("quant_dot"), input, weights, x_9, weight_scales); + auto quant_dot = mmain->add_instruction(migraphx::make_op("quant_dot"), input, weights, input_scales, weight_scales); auto bias_add = mmain->add_instruction(migraphx::make_op("add"), quant_dot, bias); mmain->add_return({bias_add}); } From 8334922e475d89adf137d8e3989971014c36df8e Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 13:52:06 -0500 Subject: [PATCH 23/43] etc --- test/verify/test_mxfp4_gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 0215dfd8383..fd43b6f9b7a 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -28,7 +28,7 @@ #include #include -instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::instruction_ref input, int block_axis, int block_size) +migraphx::instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::instruction_ref input, int block_axis, int block_size) { using migraphx::instruction_ref; using migraphx::module_ref; From d58baefbb380a0c819789415cd85d2c49d259791 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 Oct 2025 13:53:08 -0500 Subject: [PATCH 24/43] add return --- test/verify/test_mxfp4_gemm.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index fd43b6f9b7a..af9404b6808 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -118,6 +118,7 @@ struct test_mxfp4_gemm : verify_program auto quant_dot = mmain->add_instruction(migraphx::make_op("quant_dot"), input, weights, input_scales, weight_scales); auto bias_add = mmain->add_instruction(migraphx::make_op("add"), quant_dot, bias); mmain->add_return({bias_add}); + return p; } std::string section() const { return "gemm"; } }; From ca0ef2de5f86df0388edf30c2669638b03b1fc36 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 14 
Oct 2025 14:37:36 -0500 Subject: [PATCH 25/43] Fix test --- requirements.txt | 4 ++-- test/verify/test_mxfp4_gemm.cpp | 21 +++++++++++++-------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2fc0d11cbd4..1c24595a91f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -################################################################################### +##################################################################################### # The MIT License (MIT) # # Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. @@ -20,7 +20,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -################################################################################### +##################################################################################### abseil/abseil-cpp@20250512.0 -DABSL_ENABLE_INSTALL=ON -DCMAKE_POSITION_INDEPENDENT_CODE=ON google/protobuf@v30.0 -DCMAKE_POSITION_INDEPENDENT_CODE=On -Dprotobuf_BUILD_TESTS=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 nlohmann/json@v3.8.0 -DCMAKE_POLICY_VERSION_MINIMUM=3.5 diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index af9404b6808..30f84c6e049 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -98,7 +98,7 @@ migraphx::instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::i } /** - * Designed to be a quant_dot with MX block scales after the simplify_qdq pass. + * Designed to be like the final GEMM of resnet50. 
*/ struct test_mxfp4_gemm : verify_program { @@ -106,17 +106,22 @@ struct test_mxfp4_gemm : verify_program { migraphx::program p; migraphx::module_ref mmain = p.get_main_module(); - auto input = - mmain->add_parameter("x1", migraphx::shape{migraphx::shape::fp4x2_type, {1, 1024}}); + auto input = mmain->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {1, 2048}}); auto input_scales = add_dyn_scale_calc(mmain, input, 1, 32); - auto bias = mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 1000}}, 0)); - auto weights = mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::fp4x2_type, {1000, 1024}}, 2)); - auto weight_scales = add_dyn_scale_calc(mmain, weights, 1, 32); + input = mmain->add_instruction(migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), input, input_scales); + input = mmain->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 1}}), input); input = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), input); + input = mmain->add_instruction(migraphx::make_op("dequantizelinear"), input, input_scales); + auto weights = mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 2048}}, 2)); + auto weight_scales = add_dyn_scale_calc(mmain, weights, 1, 32); + weights = mmain->add_instruction(migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), weights, weight_scales); + weights = mmain->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 1}}), weights); weights = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), weights); + weights = mmain->add_instruction(migraphx::make_op("dequantizelinear"), weights, weight_scales); weights = mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), weights); - auto quant_dot = mmain->add_instruction(migraphx::make_op("quant_dot"), input, weights, input_scales, 
weight_scales); - auto bias_add = mmain->add_instruction(migraphx::make_op("add"), quant_dot, bias); + auto bias = mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 1000}}, 0)); + auto dot = mmain->add_instruction(migraphx::make_op("dot"), input, weights); + auto bias_add = mmain->add_instruction(migraphx::make_op("add"), dot, bias); mmain->add_return({bias_add}); return p; } From 81d4bd838213a8c271d036bd4749b0b3c3fb2191 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Tue, 14 Oct 2025 19:50:19 +0000 Subject: [PATCH 26/43] Update verify test tolerance --- test/verify/test_mxfp4_gemm.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 30f84c6e049..0a2f1a7c73c 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -126,4 +126,6 @@ struct test_mxfp4_gemm : verify_program return p; } std::string section() const { return "gemm"; } + + std::size_t get_tolerance() const { return 4e5; }; }; From 3b7cc58d15a1daf63c1f26b9525887c69b21e86a Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 15 Oct 2025 15:11:50 -0500 Subject: [PATCH 27/43] Fix tests and change flag behavior --- src/driver/main.cpp | 2 +- src/driver/verify.cpp | 6 +- src/driver/verify_options.hpp | 2 +- src/include/migraphx/simplify_qdq.hpp | 4 +- .../gpu/include/migraphx/gpu/device_name.hpp | 1 - src/targets/gpu/mlir.cpp | 10 +-- src/targets/gpu/target.cpp | 2 +- src/targets/ref/target.cpp | 4 +- test/simplify_qdq_test.cpp | 2 +- test/verify/test_mxfp4_gemm.cpp | 75 ++++++++++++------- 10 files changed, 63 insertions(+), 45 deletions(-) diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 4f402e06b5d..e794fefabe7 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -730,7 +730,7 @@ struct verify : command ap(bisect, {"-b", "--bisect"}, ap.help("Bisect program and verify"), ap.set_value(true)); ap(vo.ref_use_double, {"--ref-use-double"}, - 
ap.help("Convert floating point values to double on ref"), + ap.help("Convert floating point values to double on ref. Also removes Q/DQ pairs on ref."), ap.set_value(true)); ap(vo.compiled_model, {"--compiled-model", "-c"}, ap.help("Compiled model to use")); } diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp index 009b454cb45..142d43cf90b 100644 --- a/src/driver/verify.cpp +++ b/src/driver/verify.cpp @@ -97,7 +97,11 @@ static std::vector run_ref(program p, { if(vo.ref_use_double) { - run_passes(p, {fp_to_double{}}); + run_passes(p, { + fp_to_double{}, + simplify_qdq{.remove_qdq_only = true}, + dead_code_elimination{} + }); } p.compile(migraphx::make_target("ref"), options); auto out = p.eval(inputs); diff --git a/src/driver/verify_options.hpp b/src/driver/verify_options.hpp index 06343bee5ea..4c83f95f01a 100644 --- a/src/driver/verify_options.hpp +++ b/src/driver/verify_options.hpp @@ -37,7 +37,7 @@ struct verify_options precision quantize = precision::fp32; /** - * Converts floating point values to double on the ref target. + * Converts floating point values to double on the ref target. Also removes Q/DQ pairs on ref. */ bool ref_use_double = false; diff --git a/src/include/migraphx/simplify_qdq.hpp b/src/include/migraphx/simplify_qdq.hpp index 16d32534792..bf932626181 100644 --- a/src/include/migraphx/simplify_qdq.hpp +++ b/src/include/migraphx/simplify_qdq.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ struct module; struct MIGRAPHX_EXPORT simplify_qdq { bool remove_qdq_only = false; - bool use_mx_quant = false; + bool use_mx_quant = false; std::string name() const { return "simplify_qdq"; } void apply(module& m) const; }; diff --git a/src/targets/gpu/include/migraphx/gpu/device_name.hpp b/src/targets/gpu/include/migraphx/gpu/device_name.hpp index 281688b7ff2..b346aa046b3 100644 --- a/src/targets/gpu/include/migraphx/gpu/device_name.hpp +++ b/src/targets/gpu/include/migraphx/gpu/device_name.hpp @@ -47,7 +47,6 @@ MIGRAPHX_GPU_EXPORT bool gfx_has_mx_intrinsics(); MIGRAPHX_GPU_EXPORT bool gfx_has_fp8fnuz_support(); - #if MIGRAPHX_USE_HIPBLASLT MIGRAPHX_GPU_EXPORT bool gfx_default_rocblas(); #endif diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 9e744a3e63b..a85a3540c98 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -513,9 +513,9 @@ struct mlir_program MlirAttribute segment_sizes_attr = mlirDenseI32ArrayGet(prog->ctx.get(), num_segments, sizes.data()); MlirNamedAttribute named_attr = mlirNamedAttributeGet( - mlirIdentifierGet(prog->ctx.get(), - mlirStringRefCreateFromCString("operandSegmentSizes")), - segment_sizes_attr); + mlirIdentifierGet(prog->ctx.get(), + mlirStringRefCreateFromCString("operandSegmentSizes")), + segment_sizes_attr); mlirOperationStateAddAttributes(&op_state, 1, &named_attr); } @@ -768,8 +768,8 @@ struct mlir_program ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); }); if(ins->name() == "quant_dot" and - ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type and - ins->inputs().size() == 4) + ins->inputs().size() == 4 and + ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type) { // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. 
// Use the canonical MLIR attribute name 'operandSegmentSizes'. diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 0fb9376700a..e8f9c19e84b 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -189,7 +189,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), fp8_ocp_to_fnuz{}), enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), dead_code_elimination{}), - simplify_qdq{.use_mx_quant=gfx_has_mx_intrinsics()}, + simplify_qdq{.use_mx_quant=gpu::gfx_has_mx_intrinsics()}, enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, rewrite_rnn{}, diff --git a/src/targets/ref/target.cpp b/src/targets/ref/target.cpp index 5dde5027c47..0a7f7545a2a 100644 --- a/src/targets/ref/target.cpp +++ b/src/targets/ref/target.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -52,8 +52,6 @@ std::vector target::get_passes(migraphx::context&, const compile_options&) dead_code_elimination{}, rewrite_rnn{}, dead_code_elimination{}, - simplify_qdq{.remove_qdq_only = true}, - dead_code_elimination{}, auto_contiguous{}, dead_code_elimination{}, lowering{}, diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index 8d548224762..4e2c2ddcb30 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -44,7 +44,7 @@ static bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot static void run_pass(migraphx::module& m) { - run_passes(m, {migraphx::simplify_qdq{}, migraphx::dead_code_elimination{}}); + run_passes(m, {migraphx::simplify_qdq{.remove_qdq_only=false, .use_mx_quant=true}, migraphx::dead_code_elimination{}}); } static void run_cse(migraphx::module& m) diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 0a2f1a7c73c..0a0756d86b1 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -28,19 +28,25 @@ #include #include -migraphx::instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::instruction_ref input, int block_axis, int block_size) +#include + +namespace { +migraphx::instruction_ref add_dyn_scale_calc(migraphx::module_ref m, + migraphx::instruction_ref input, + int block_axis, + int block_size) { using migraphx::instruction_ref; - using migraphx::module_ref; using migraphx::make_op; + using migraphx::module_ref; // Code similar to that in parse_mxfixneruon // make reduction axes for calculating block scales // tmp_lens != input_lens if runt block is padded - instruction_ref tmp_in = input; - const auto input_lens = input->get_shape().lens(); - auto tmp_lens = input_lens; - auto block_dim = tmp_lens.at(block_axis); + instruction_ref tmp_in = input; + 
const auto input_lens = input->get_shape().lens(); + auto tmp_lens = input_lens; + auto block_dim = tmp_lens.at(block_axis); std::size_t block_padding = std::ceil(double(block_dim) / double(block_size)) * block_size - block_dim; // handle runt block by padding @@ -56,7 +62,8 @@ migraphx::instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::i std::vector reduct_dims = tmp_lens; reduct_dims.at(block_axis) = block_size; reduct_dims.insert(reduct_dims.begin() + block_axis, num_blocks); - instruction_ref reshape_ins = m->add_instruction(make_op("reshape", {{"dims", reduct_dims}}), tmp_in); + instruction_ref reshape_ins = + m->add_instruction(make_op("reshape", {{"dims", reduct_dims}}), tmp_in); // dynamic quantization for MX types: // V_k = fp32 vector input of block size k @@ -67,22 +74,20 @@ migraphx::instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::i m->add_instruction(make_op("reduce_max", {{"axes", {block_axis + 1}}}), abs_ins); auto log2_ins = m->add_instruction(make_op("log2"), reduce_max_ins); auto floor_ins = m->add_instruction(make_op("floor"), log2_ins); - auto lit_2_ins = m->add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {2.f}}); + auto lit_2_ins = + m->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {2.f}}); auto broadcast_lit_2_ins = m->add_instruction( - make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), - lit_2_ins); - auto pow_ins = m->add_instruction(make_op("pow"), broadcast_lit_2_ins, floor_ins); - auto lit_4_ins = m->add_literal( - migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {4.f}}); + make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), lit_2_ins); + auto pow_ins = m->add_instruction(make_op("pow"), broadcast_lit_2_ins, floor_ins); + auto lit_4_ins = + m->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {4.f}}); auto broadcast_lit_4_ins = 
m->add_instruction( - make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), - lit_4_ins); + make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}), lit_4_ins); auto block_scales_ins = m->add_instruction(make_op("div"), pow_ins, broadcast_lit_4_ins); // broadcast scales for use in quantizelinear - block_scales_ins = m->add_instruction( - make_op("multibroadcast", {{"out_lens", reduct_dims}}), block_scales_ins); + block_scales_ins = m->add_instruction(make_op("multibroadcast", {{"out_lens", reduct_dims}}), + block_scales_ins); block_scales_ins = m->add_instruction(make_op("reshape", {{"dims", tmp_lens}}), block_scales_ins); @@ -91,11 +96,12 @@ migraphx::instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::i { std::size_t slice_size = input_lens.at(block_axis); block_scales_ins = m->add_instruction( - make_op("slice", {{"axes", {block_axis}}, {"starts", {0}}, {"ends", {slice_size}}}), - block_scales_ins); + make_op("slice", {{"axes", {block_axis}}, {"starts", {0}}, {"ends", {slice_size}}}), + block_scales_ins); } return block_scales_ins; } +} // namespace /** * Designed to be like the final GEMM of resnet50. 
@@ -106,26 +112,37 @@ struct test_mxfp4_gemm : verify_program { migraphx::program p; migraphx::module_ref mmain = p.get_main_module(); - auto input = mmain->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {1, 2048}}); + auto input = + mmain->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {1, 2048}}); auto input_scales = add_dyn_scale_calc(mmain, input, 1, 32); - input = mmain->add_instruction(migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), input, input_scales); + input = mmain->add_instruction( + migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), + input, + input_scales); input = mmain->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 1}}), input); input = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), input); input = mmain->add_instruction(migraphx::make_op("dequantizelinear"), input, input_scales); - auto weights = mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 2048}}, 2)); + auto weights = mmain->add_literal(migraphx::generate_literal( + migraphx::shape{migraphx::shape::float_type, {1000, 2048}}, 2)); auto weight_scales = add_dyn_scale_calc(mmain, weights, 1, 32); - weights = mmain->add_instruction(migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), weights, weight_scales); + weights = mmain->add_instruction( + migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), + weights, + weight_scales); weights = mmain->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 1}}), weights); weights = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), weights); - weights = mmain->add_instruction(migraphx::make_op("dequantizelinear"), weights, weight_scales); - weights = mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), weights); - auto bias = 
mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 1000}}, 0)); - auto dot = mmain->add_instruction(migraphx::make_op("dot"), input, weights); + weights = + mmain->add_instruction(migraphx::make_op("dequantizelinear"), weights, weight_scales); + weights = mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), + weights); + auto bias = mmain->add_literal( + migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 1000}}, 0)); + auto dot = mmain->add_instruction(migraphx::make_op("dot"), input, weights); auto bias_add = mmain->add_instruction(migraphx::make_op("add"), dot, bias); mmain->add_return({bias_add}); return p; } std::string section() const { return "gemm"; } - std::size_t get_tolerance() const { return 4e5; }; + //std::size_t get_tolerance() const { return 4e5; }; }; From 8e97a82af4889d5ff96c5c1faa75e3a16a3e4c0a Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 15 Oct 2025 15:13:39 -0500 Subject: [PATCH 28/43] Typo fix --- src/driver/verify.cpp | 2 +- src/targets/gpu/mlir.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp index 142d43cf90b..34f1f8fa8d7 100644 --- a/src/driver/verify.cpp +++ b/src/driver/verify.cpp @@ -44,7 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS { /** * Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults. * Sets to fp4 tolerances if any fp4x2_type is found. - * Else sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the + * Else sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction is found in the * model. 
*/ verify::tolerance get_tolerances(const program& p, diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index a85a3540c98..c2ca400f5c6 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -508,7 +508,7 @@ struct mlir_program { } - void set_operand_segement_sizes(int num_segments, const std::vector& sizes) + void set_operand_segment_sizes(int num_segments, const std::vector& sizes) { MlirAttribute segment_sizes_attr = mlirDenseI32ArrayGet(prog->ctx.get(), num_segments, sizes.data()); @@ -774,7 +774,7 @@ struct mlir_program // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. // Use the canonical MLIR attribute name 'operandSegmentSizes'. const std::vector seg_sizes = {1, 1, 1, 1}; - ops.set_operand_segement_sizes(4, seg_sizes); + ops.set_operand_segment_sizes(4, seg_sizes); } ops.add_operands(inputs); From f4c30b95c22398f0558a424e80f1c2dd8928bf78 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 15 Oct 2025 16:12:27 -0500 Subject: [PATCH 29/43] Typo and include fixes --- src/driver/verify.cpp | 1 + src/targets/ref/target.cpp | 1 - test/verify/test_mxfp4_gemm.cpp | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp index 34f1f8fa8d7..a51d21c8b24 100644 --- a/src/driver/verify.cpp +++ b/src/driver/verify.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include namespace migraphx { diff --git a/src/targets/ref/target.cpp b/src/targets/ref/target.cpp index 0a7f7545a2a..66f30f3de60 100644 --- a/src/targets/ref/target.cpp +++ b/src/targets/ref/target.cpp @@ -35,7 +35,6 @@ #include #include #include -#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 0a0756d86b1..116bf62a091 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -40,7 +40,7 @@ migraphx::instruction_ref add_dyn_scale_calc(migraphx::module_ref m, 
using migraphx::make_op; using migraphx::module_ref; - // Code similar to that in parse_mxfixneruon + // Code similar to that in parse_mxfixneuron // make reduction axes for calculating block scales // tmp_lens != input_lens if runt block is padded instruction_ref tmp_in = input; From 6940e9f763a148362bd2cebca74e4e2c7bf2a451 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 15 Oct 2025 16:14:17 -0500 Subject: [PATCH 30/43] formatting --- src/driver/main.cpp | 3 ++- src/driver/verify.cpp | 7 ++----- src/targets/gpu/mlir.cpp | 3 +-- test/simplify_qdq_test.cpp | 6 ++++-- test/verify/test_mxfp4_gemm.cpp | 2 +- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/driver/main.cpp b/src/driver/main.cpp index e794fefabe7..a5ea6576b35 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -730,7 +730,8 @@ struct verify : command ap(bisect, {"-b", "--bisect"}, ap.help("Bisect program and verify"), ap.set_value(true)); ap(vo.ref_use_double, {"--ref-use-double"}, - ap.help("Convert floating point values to double on ref. Also removes Q/DQ pairs on ref."), + ap.help( + "Convert floating point values to double on ref. 
Also removes Q/DQ pairs on ref."), ap.set_value(true)); ap(vo.compiled_model, {"--compiled-model", "-c"}, ap.help("Compiled model to use")); } diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp index a51d21c8b24..6ac5d3dcb24 100644 --- a/src/driver/verify.cpp +++ b/src/driver/verify.cpp @@ -98,11 +98,8 @@ static std::vector run_ref(program p, { if(vo.ref_use_double) { - run_passes(p, { - fp_to_double{}, - simplify_qdq{.remove_qdq_only = true}, - dead_code_elimination{} - }); + run_passes( + p, {fp_to_double{}, simplify_qdq{.remove_qdq_only = true}, dead_code_elimination{}}); } p.compile(migraphx::make_target("ref"), options); auto out = p.eval(inputs); diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index c2ca400f5c6..a454636a945 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -767,8 +767,7 @@ struct mlir_program transform( ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); }); - if(ins->name() == "quant_dot" and - ins->inputs().size() == 4 and + if(ins->name() == "quant_dot" and ins->inputs().size() == 4 and ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type) { // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. 
diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index 4e2c2ddcb30..160a457233e 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -44,7 +44,9 @@ static bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot static void run_pass(migraphx::module& m) { - run_passes(m, {migraphx::simplify_qdq{.remove_qdq_only=false, .use_mx_quant=true}, migraphx::dead_code_elimination{}}); + run_passes(m, + {migraphx::simplify_qdq{.remove_qdq_only = false, .use_mx_quant = true}, + migraphx::dead_code_elimination{}}); } static void run_cse(migraphx::module& m) @@ -1456,7 +1458,7 @@ TEST_CASE(dot_reused) auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens()); auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2); auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1]); - auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3); + auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3); m2.add_return({add2}); } diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 116bf62a091..1ff2be7c1cd 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -144,5 +144,5 @@ struct test_mxfp4_gemm : verify_program } std::string section() const { return "gemm"; } - //std::size_t get_tolerance() const { return 4e5; }; + // std::size_t get_tolerance() const { return 4e5; }; }; From 387114bf9a930abf196cc99b71e385afb9da71c9 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Thu, 16 Oct 2025 21:11:08 +0000 Subject: [PATCH 31/43] Fix mlir compilation for operand_size --- src/targets/gpu/mlir.cpp | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index c2ca400f5c6..a5ff37b2c78 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -766,16 +766,25 @@ struct mlir_program std::vector inputs; transform( ins->inputs(), 
std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); }); - - if(ins->name() == "quant_dot" and - ins->inputs().size() == 4 and - ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type) - { - // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. - // Use the canonical MLIR attribute name 'operandSegmentSizes'. - const std::vector seg_sizes = {1, 1, 1, 1}; - ops.set_operand_segment_sizes(4, seg_sizes); - } + if(ins->name() == "dot") { + const std::vector seg_sizes = {1, 1, 0, 0}; + ops.set_operand_segment_sizes(4, seg_sizes); + } + else if(ins->name() == "quant_dot") + { + if(ins->inputs().size() == 4 and ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type) + { + // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. + // Use the canonical MLIR attribute name 'operandSegmentSizes'. + const std::vector seg_sizes = {1, 1, 1, 1}; + ops.set_operand_segment_sizes(4, seg_sizes); + } + else if(ins->inputs().size() == 2) + { + const std::vector seg_sizes = {1, 1, 0, 0}; + ops.set_operand_segment_sizes(4, seg_sizes); + } + } ops.add_operands(inputs); auto outputs = insert(fbody, std::move(ops)); From 991ca0ddef14ea2be301841bb5eee95e44163b35 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 16 Oct 2025 16:15:38 -0500 Subject: [PATCH 32/43] Update changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8fe901d0795..f9008654016 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,11 @@ Full documentation for MIGraphX is available at [https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/](https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/). 
+## Develop Branch + +### Added +* Added MXFP4 support for Quark and Brevitas quantized models (GEMMs only) (#4343) + ## MIGraphX 2.14 for ROCm 7.1.0 From 92aea679b36a4a968f1d87916538dd7a508a0de8 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 17 Oct 2025 16:19:49 -0500 Subject: [PATCH 33/43] Add dead_code_elim dependency --- src/driver/verify.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/driver/verify.cpp b/src/driver/verify.cpp index 6ac5d3dcb24..ff201f0a4db 100644 --- a/src/driver/verify.cpp +++ b/src/driver/verify.cpp @@ -36,6 +36,7 @@ #include #include #include +#include #include namespace migraphx { From 4b2c6c0302ca22c9522786ddaf76ecc38595f843 Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 21 Oct 2025 12:20:50 -0500 Subject: [PATCH 34/43] Add back tolerance for verify test Doesn't pass with not-MI350 otherwise --- test/verify/test_mxfp4_gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 1ff2be7c1cd..152a6fc3c70 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -144,5 +144,5 @@ struct test_mxfp4_gemm : verify_program } std::string section() const { return "gemm"; } - // std::size_t get_tolerance() const { return 4e5; }; + std::size_t get_tolerance() const { return 4e5; }; }; From f6ea7e0a968540f3b86c306503f166528bcde876 Mon Sep 17 00:00:00 2001 From: charlie Date: Wed, 29 Oct 2025 14:24:36 -0500 Subject: [PATCH 35/43] AIMIGRAPHX-193 use vector sizes --- src/targets/gpu/mlir.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index a5ff37b2c78..e59a96c85e6 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -508,10 +508,10 @@ struct mlir_program { } - void set_operand_segment_sizes(int num_segments, const std::vector& sizes) + void set_operand_segment_sizes(const std::vector& sizes) { MlirAttribute 
segment_sizes_attr = - mlirDenseI32ArrayGet(prog->ctx.get(), num_segments, sizes.data()); + mlirDenseI32ArrayGet(prog->ctx.get(), sizes.size(), sizes.data()); MlirNamedAttribute named_attr = mlirNamedAttributeGet( mlirIdentifierGet(prog->ctx.get(), mlirStringRefCreateFromCString("operandSegmentSizes")), @@ -768,7 +768,7 @@ struct mlir_program ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); }); if(ins->name() == "dot") { const std::vector seg_sizes = {1, 1, 0, 0}; - ops.set_operand_segment_sizes(4, seg_sizes); + ops.set_operand_segment_sizes(seg_sizes); } else if(ins->name() == "quant_dot") { @@ -777,12 +777,12 @@ struct mlir_program // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. // Use the canonical MLIR attribute name 'operandSegmentSizes'. const std::vector seg_sizes = {1, 1, 1, 1}; - ops.set_operand_segment_sizes(4, seg_sizes); + ops.set_operand_segment_sizes(seg_sizes); } else if(ins->inputs().size() == 2) { const std::vector seg_sizes = {1, 1, 0, 0}; - ops.set_operand_segment_sizes(4, seg_sizes); + ops.set_operand_segment_sizes(seg_sizes); } } ops.add_operands(inputs); From a0806545c12af6c9140138a8db0cde8446798d7a Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 31 Oct 2025 11:06:56 -0500 Subject: [PATCH 36/43] Update mlir tests --- test/gpu/mlir.cpp | 66 +++++++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index 387b323e4f2..ef6c7483b99 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -304,39 +304,39 @@ module { // EXPECT(verify_mlir(m)); } -TEST_CASE(conv_backwards) -{ - std::string mlir_output = R"__migraphx__( -module { - func.func @mlir_convolution_backwards(%arg0: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>, %arg1: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> attributes ${attrs} { - %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 
1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x1x3x3xf32, 9x9x3x1>, <1x1x3x3xf32, 9x9x3x1> -> <1x1x5x5xf32, 25x25x5x1> - return %0 : !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> - } -} -)__migraphx__"; - - migraphx::module m; - auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); - auto w = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); - auto conv_b = m.add_instruction(migraphx::make_op("convolution_backwards"), x, w); - m.add_return({conv_b}); - - auto s = migraphx::gpu::dump_mlir(m); - // Skip test if MLIR is not enabled - if(s.empty()) - return; - auto mlir_output_with_attrs = - migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); - CHECK(encode(s) == encode(mlir_output_with_attrs)); - EXPECT(verify_mlir(m)); -} +//TEST_CASE(conv_backwards) +//{ +// std::string mlir_output = R"__migraphx__( +//module { +// func.func @mlir_convolution_backwards(%arg0: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>, %arg1: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> attributes ${attrs} { +// %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x1x3x3xf32, 9x9x3x1>, <1x1x3x3xf32, 9x9x3x1> -> <1x1x5x5xf32, 25x25x5x1> +// return %0 : !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> +// } +//} +//)__migraphx__"; +// +// migraphx::module m; +// auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); +// auto w = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); +// auto conv_b = m.add_instruction(migraphx::make_op("convolution_backwards"), x, w); +// m.add_return({conv_b}); +// +// auto s = migraphx::gpu::dump_mlir(m); +// // Skip test if MLIR is not enabled +// if(s.empty()) +// return; +// auto mlir_output_with_attrs = +// 
migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); +// CHECK(encode(s) == encode(mlir_output_with_attrs)); +// EXPECT(verify_mlir(m)); +//} TEST_CASE(quant_dot_add) { std::string mlir_output = R"__migraphx__( module { func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xsi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xsi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xsi32, 15x3x1>) -> !migraphx.shaped<1x5x3xsi32, 15x3x1> attributes ${attrs} { - %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xsi8, 20x4x1>, <1x4x3xsi8, 12x3x1> -> <1x5x3xsi32, 15x3x1> + %0 = migraphx.quant_dot %arg0, %arg1 {operandSegmentSizes = array} : <1x5x4xsi8, 20x4x1>, <1x4x3xsi8, 12x3x1> -> <1x5x3xsi32, 15x3x1> %1 = migraphx.add %0, %arg2 : <1x5x3xsi32, 15x3x1>, <1x5x3xsi32, 15x3x1> -> <1x5x3xsi32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xsi32, 15x3x1> } @@ -365,7 +365,7 @@ TEST_CASE(dot_add) std::string mlir_output = R"__migraphx__( module { func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { - %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> + %0 = migraphx.dot %arg0, %arg1 {operandSegmentSizes = array} : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1> } @@ -394,7 +394,7 @@ TEST_CASE(unsqueeze_dot_add) module { func.func @mlir_unsqueeze_dot_add(%arg0: !migraphx.shaped<5x4xf32, 4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { %0 = migraphx.reshape %arg0 {dims = [1, 5, 4]} : <5x4xf32, 4x1> -> <1x5x4xf32, 20x4x1> - %1 = migraphx.dot %0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> + %1 = 
migraphx.dot %0, %arg1 {operandSegmentSizes = array} : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %2 = migraphx.add %1, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> return %2 : !migraphx.shaped<1x5x3xf32, 15x3x1> } @@ -458,7 +458,7 @@ TEST_CASE(dot_convert) std::string mlir_output = R"__migraphx__( module { func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes ${attrs} { - %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> + %0 = migraphx.dot %arg0, %arg1 {operandSegmentSizes = array} : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1> } @@ -486,7 +486,7 @@ TEST_CASE(dot_where) std::string mlir_output = R"__migraphx__( module { func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xsi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { - %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> + %0 = migraphx.dot %arg0, %arg1 {operandSegmentSizes = array} : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xsi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1> } @@ -579,7 +579,7 @@ module { %7 = migraphx.slice %6 {axes = [1], ends = [5], starts = [0]} : <2x6x2xsi8, 12x2x1> -> <2x5x2xsi8, 12x2x1> %8 = migraphx.unpack %arg1 {axis = 2 : i64} : <2x5x1xsi8, 5x1x1> -> <2x5x2xsi8, 10x2x1> %9 = migraphx.dequantizelinear %8, %3, %7 : <2x5x2xsi8, 10x2x1>, <2x5x2xf32, 12x2x1>, 
!migraphx.shaped<2x5x2xsi8, 12x2x1> -> <2x5x2xf32, 10x2x1> - %10 = migraphx.dot %arg0, %9 : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> + %10 = migraphx.dot %arg0, %9 {operandSegmentSizes = array} : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> return %10 : !migraphx.shaped<2x3x2xf32, 6x2x1> } } @@ -635,7 +635,7 @@ module { %7 = migraphx.slice %6 {axes = [1], ends = [5], starts = [0]} : <2x6x2xui8, 12x2x1> -> <2x5x2xui8, 12x2x1> %8 = migraphx.unpack %arg1 {axis = 2 : i64} : <2x5x1xui8, 5x1x1> -> <2x5x2xui8, 10x2x1> %9 = migraphx.dequantizelinear %8, %3, %7 : <2x5x2xui8, 10x2x1>, <2x5x2xf32, 12x2x1>, !migraphx.shaped<2x5x2xui8, 12x2x1> -> <2x5x2xf32, 10x2x1> - %10 = migraphx.dot %arg0, %9 : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> + %10 = migraphx.dot %arg0, %9 {operandSegmentSizes = array} : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> return %10 : !migraphx.shaped<2x3x2xf32, 6x2x1> } } From 4bdd67592bff7162c32d6b9c22027222ffce0ad5 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 31 Oct 2025 11:09:39 -0500 Subject: [PATCH 37/43] Remove fp4 check on quant_dot rocmlir compile --- src/targets/gpu/mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index e59a96c85e6..180a42f0baf 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -772,7 +772,7 @@ struct mlir_program } else if(ins->name() == "quant_dot") { - if(ins->inputs().size() == 4 and ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type) + if(ins->inputs().size() == 4) { // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. // Use the canonical MLIR attribute name 'operandSegmentSizes'. 
From 6c56366b0d29de63fc073be7c605cb5038adec06 Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 31 Oct 2025 15:25:59 -0500 Subject: [PATCH 38/43] Formatting --- src/targets/gpu/mlir.cpp | 39 ++++++++++++++++---------------- test/gpu/mlir.cpp | 48 ++++++++++++++++++++++------------------ 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 180a42f0baf..2515ddb66bf 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -766,25 +766,26 @@ struct mlir_program std::vector inputs; transform( ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); }); - if(ins->name() == "dot") { - const std::vector seg_sizes = {1, 1, 0, 0}; - ops.set_operand_segment_sizes(seg_sizes); - } - else if(ins->name() == "quant_dot") - { - if(ins->inputs().size() == 4) - { - // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. - // Use the canonical MLIR attribute name 'operandSegmentSizes'. - const std::vector seg_sizes = {1, 1, 1, 1}; - ops.set_operand_segment_sizes(seg_sizes); - } - else if(ins->inputs().size() == 2) - { - const std::vector seg_sizes = {1, 1, 0, 0}; - ops.set_operand_segment_sizes(seg_sizes); - } - } + if(ins->name() == "dot") + { + const std::vector seg_sizes = {1, 1, 0, 0}; + ops.set_operand_segment_sizes(seg_sizes); + } + else if(ins->name() == "quant_dot") + { + if(ins->inputs().size() == 4) + { + // Specify operand segment sizes BEFORE creating the operation so MLIR sees it. + // Use the canonical MLIR attribute name 'operandSegmentSizes'. 
+ const std::vector seg_sizes = {1, 1, 1, 1}; + ops.set_operand_segment_sizes(seg_sizes); + } + else if(ins->inputs().size() == 2) + { + const std::vector seg_sizes = {1, 1, 0, 0}; + ops.set_operand_segment_sizes(seg_sizes); + } + } ops.add_operands(inputs); auto outputs = insert(fbody, std::move(ops)); diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index ef6c7483b99..2e551ce4dba 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -304,32 +304,36 @@ module { // EXPECT(verify_mlir(m)); } -//TEST_CASE(conv_backwards) +// TEST_CASE(conv_backwards) //{ -// std::string mlir_output = R"__migraphx__( -//module { -// func.func @mlir_convolution_backwards(%arg0: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>, %arg1: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> attributes ${attrs} { -// %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x1x3x3xf32, 9x9x3x1>, <1x1x3x3xf32, 9x9x3x1> -> <1x1x5x5xf32, 25x25x5x1> -// return %0 : !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> -// } -//} +// std::string mlir_output = R"__migraphx__( +// module { +// func.func @mlir_convolution_backwards(%arg0: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>, %arg1: +// !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> attributes +// ${attrs} { +// %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, +// padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x1x3x3xf32, 9x9x3x1>, +// <1x1x3x3xf32, 9x9x3x1> -> <1x1x5x5xf32, 25x25x5x1> return %0 : !migraphx.shaped<1x1x5x5xf32, +// 25x25x5x1> +// } +// } //)__migraphx__"; // -// migraphx::module m; -// auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); -// auto w = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); -// auto conv_b = 
m.add_instruction(migraphx::make_op("convolution_backwards"), x, w); -// m.add_return({conv_b}); +// migraphx::module m; +// auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, +// 3}}); auto w = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, +// 3, 3}}); auto conv_b = m.add_instruction(migraphx::make_op("convolution_backwards"), x, w); +// m.add_return({conv_b}); // -// auto s = migraphx::gpu::dump_mlir(m); -// // Skip test if MLIR is not enabled -// if(s.empty()) -// return; -// auto mlir_output_with_attrs = -// migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); -// CHECK(encode(s) == encode(mlir_output_with_attrs)); -// EXPECT(verify_mlir(m)); -//} +// auto s = migraphx::gpu::dump_mlir(m); +// // Skip test if MLIR is not enabled +// if(s.empty()) +// return; +// auto mlir_output_with_attrs = +// migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); +// CHECK(encode(s) == encode(mlir_output_with_attrs)); +// EXPECT(verify_mlir(m)); +// } TEST_CASE(quant_dot_add) { From 84cef3984fec447dea8fde9192dd80760fbc6812 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 6 Nov 2025 13:50:48 -0600 Subject: [PATCH 39/43] Remove dot operand sizes and tests --- src/targets/gpu/mlir.cpp | 7 +--- test/gpu/mlir.cpp | 74 ++++++++++++++++++++-------------------- 2 files changed, 38 insertions(+), 43 deletions(-) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 2515ddb66bf..6b15f193026 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -766,12 +766,7 @@ struct mlir_program std::vector inputs; transform( ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); }); - if(ins->name() == "dot") - { - const std::vector seg_sizes = {1, 1, 0, 0}; - ops.set_operand_segment_sizes(seg_sizes); - } - else if(ins->name() == "quant_dot") + if(ins->name() == "quant_dot") { if(ins->inputs().size() == 4) { diff --git a/test/gpu/mlir.cpp 
b/test/gpu/mlir.cpp index 2e551ce4dba..841aa9c398c 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -304,43 +304,43 @@ module { // EXPECT(verify_mlir(m)); } -// TEST_CASE(conv_backwards) -//{ -// std::string mlir_output = R"__migraphx__( -// module { -// func.func @mlir_convolution_backwards(%arg0: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>, %arg1: -// !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> attributes -// ${attrs} { -// %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, -// padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x1x3x3xf32, 9x9x3x1>, -// <1x1x3x3xf32, 9x9x3x1> -> <1x1x5x5xf32, 25x25x5x1> return %0 : !migraphx.shaped<1x1x5x5xf32, -// 25x25x5x1> -// } -// } -//)__migraphx__"; -// -// migraphx::module m; -// auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, -// 3}}); auto w = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, -// 3, 3}}); auto conv_b = m.add_instruction(migraphx::make_op("convolution_backwards"), x, w); -// m.add_return({conv_b}); -// -// auto s = migraphx::gpu::dump_mlir(m); -// // Skip test if MLIR is not enabled -// if(s.empty()) -// return; -// auto mlir_output_with_attrs = -// migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); -// CHECK(encode(s) == encode(mlir_output_with_attrs)); -// EXPECT(verify_mlir(m)); -// } +TEST_CASE(conv_backwards) +{ + std::string mlir_output = R"__migraphx__( + module { + func.func @mlir_convolution_backwards(%arg0: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>, %arg1: + !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> attributes + ${attrs} { + %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, + padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x1x3x3xf32, 9x9x3x1>, + <1x1x3x3xf32, 9x9x3x1> -> <1x1x5x5xf32, 25x25x5x1> return %0 : 
!migraphx.shaped<1x1x5x5xf32, + 25x25x5x1> + } + } +)__migraphx__"; + + migraphx::module m; + auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, + 3}}); auto w = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, + 3, 3}}); auto conv_b = m.add_instruction(migraphx::make_op("convolution_backwards"), x, w); + m.add_return({conv_b}); + + auto s = migraphx::gpu::dump_mlir(m); + // Skip test if MLIR is not enabled + if(s.empty()) + return; + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); + EXPECT(verify_mlir(m)); +} TEST_CASE(quant_dot_add) { std::string mlir_output = R"__migraphx__( module { func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xsi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xsi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xsi32, 15x3x1>) -> !migraphx.shaped<1x5x3xsi32, 15x3x1> attributes ${attrs} { - %0 = migraphx.quant_dot %arg0, %arg1 {operandSegmentSizes = array} : <1x5x4xsi8, 20x4x1>, <1x4x3xsi8, 12x3x1> -> <1x5x3xsi32, 15x3x1> + %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xsi8, 20x4x1>, <1x4x3xsi8, 12x3x1> -> <1x5x3xsi32, 15x3x1> %1 = migraphx.add %0, %arg2 : <1x5x3xsi32, 15x3x1>, <1x5x3xsi32, 15x3x1> -> <1x5x3xsi32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xsi32, 15x3x1> } @@ -369,7 +369,7 @@ TEST_CASE(dot_add) std::string mlir_output = R"__migraphx__( module { func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { - %0 = migraphx.dot %arg0, %arg1 {operandSegmentSizes = array} : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> + %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 
15x3x1> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1> } @@ -398,7 +398,7 @@ TEST_CASE(unsqueeze_dot_add) module { func.func @mlir_unsqueeze_dot_add(%arg0: !migraphx.shaped<5x4xf32, 4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { %0 = migraphx.reshape %arg0 {dims = [1, 5, 4]} : <5x4xf32, 4x1> -> <1x5x4xf32, 20x4x1> - %1 = migraphx.dot %0, %arg1 {operandSegmentSizes = array} : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> + %1 = migraphx.dot %0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %2 = migraphx.add %1, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> return %2 : !migraphx.shaped<1x5x3xf32, 15x3x1> } @@ -462,7 +462,7 @@ TEST_CASE(dot_convert) std::string mlir_output = R"__migraphx__( module { func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes ${attrs} { - %0 = migraphx.dot %arg0, %arg1 {operandSegmentSizes = array} : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> + %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> %1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1> } @@ -490,7 +490,7 @@ TEST_CASE(dot_where) std::string mlir_output = R"__migraphx__( module { func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xsi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { - %0 = migraphx.dot %arg0, %arg1 {operandSegmentSizes = array} : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> + %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> 
<1x5x3xf32, 15x3x1> %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xsi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1> } @@ -583,7 +583,7 @@ module { %7 = migraphx.slice %6 {axes = [1], ends = [5], starts = [0]} : <2x6x2xsi8, 12x2x1> -> <2x5x2xsi8, 12x2x1> %8 = migraphx.unpack %arg1 {axis = 2 : i64} : <2x5x1xsi8, 5x1x1> -> <2x5x2xsi8, 10x2x1> %9 = migraphx.dequantizelinear %8, %3, %7 : <2x5x2xsi8, 10x2x1>, <2x5x2xf32, 12x2x1>, !migraphx.shaped<2x5x2xsi8, 12x2x1> -> <2x5x2xf32, 10x2x1> - %10 = migraphx.dot %arg0, %9 {operandSegmentSizes = array} : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> + %10 = migraphx.dot %arg0, %9 : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> return %10 : !migraphx.shaped<2x3x2xf32, 6x2x1> } } @@ -639,7 +639,7 @@ module { %7 = migraphx.slice %6 {axes = [1], ends = [5], starts = [0]} : <2x6x2xui8, 12x2x1> -> <2x5x2xui8, 12x2x1> %8 = migraphx.unpack %arg1 {axis = 2 : i64} : <2x5x1xui8, 5x1x1> -> <2x5x2xui8, 10x2x1> %9 = migraphx.dequantizelinear %8, %3, %7 : <2x5x2xui8, 10x2x1>, <2x5x2xf32, 12x2x1>, !migraphx.shaped<2x5x2xui8, 12x2x1> -> <2x5x2xf32, 10x2x1> - %10 = migraphx.dot %arg0, %9 {operandSegmentSizes = array} : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> + %10 = migraphx.dot %arg0, %9 : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> return %10 : !migraphx.shaped<2x3x2xf32, 6x2x1> } } From 0a7c672e5b8bf433d8bd450e096f70b868687ead Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 6 Nov 2025 13:54:33 -0600 Subject: [PATCH 40/43] Fix whitespace --- test/gpu/mlir.cpp | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index 841aa9c398c..244fab78ed2 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -307,22 +307,18 @@ module { TEST_CASE(conv_backwards) { std::string mlir_output = R"__migraphx__( - module { - 
func.func @mlir_convolution_backwards(%arg0: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>, %arg1: - !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> attributes - ${attrs} { - %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, - padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x1x3x3xf32, 9x9x3x1>, - <1x1x3x3xf32, 9x9x3x1> -> <1x1x5x5xf32, 25x25x5x1> return %0 : !migraphx.shaped<1x1x5x5xf32, - 25x25x5x1> - } - } -)__migraphx__"; + module { + func.func @mlir_convolution_backwards(%arg0: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>, %arg1: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> attributes ${attrs} { + %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x1x3x3xf32, 9x9x3x1>, <1x1x3x3xf32, 9x9x3x1> -> <1x1x5x5xf32, 25x25x5x1> + return %0 : !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> + } + } + )__migraphx__"; migraphx::module m; - auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, - 3}}); auto w = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, - 3, 3}}); auto conv_b = m.add_instruction(migraphx::make_op("convolution_backwards"), x, w); + auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); + auto w = m.add_parameter("w", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); + auto conv_b = m.add_instruction(migraphx::make_op("convolution_backwards"), x, w); m.add_return({conv_b}); auto s = migraphx::gpu::dump_mlir(m); From 55565157857e13ee5e1faf1cd47fafbda006ca57 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 6 Nov 2025 13:55:56 -0600 Subject: [PATCH 41/43] More whitespace --- test/gpu/mlir.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index 244fab78ed2..387b323e4f2 
100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -307,13 +307,13 @@ module { TEST_CASE(conv_backwards) { std::string mlir_output = R"__migraphx__( - module { +module { func.func @mlir_convolution_backwards(%arg0: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>, %arg1: !migraphx.shaped<1x1x3x3xf32, 9x9x3x1>) -> !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> attributes ${attrs} { %0 = migraphx.backwards_data_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x1x3x3xf32, 9x9x3x1>, <1x1x3x3xf32, 9x9x3x1> -> <1x1x5x5xf32, 25x25x5x1> return %0 : !migraphx.shaped<1x1x5x5xf32, 25x25x5x1> - } - } - )__migraphx__"; + } +} +)__migraphx__"; migraphx::module m; auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 3}}); From 1f3ac4e68bb209b620c8351e443011086b789965 Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Fri, 7 Nov 2025 19:43:20 -0500 Subject: [PATCH 42/43] Add rocMLIR updates --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ae146d8614c..a6274be4856 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@fe6da4db4d6f0da8c74e28a0787cfbb4a026550a -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off +ROCm/rocMLIR@fa1136f4d5e8580bce5284f997f06c8738131066 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off From 50adffdd432f7ed59817d3a0ff7ee3e2e7f2ebe8 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Fri, 7 Nov 2025 21:22:34 -0500 Subject: [PATCH 43/43] MXFP4: rocMLIR compilation remove unpack_fp4 and 
bugfixes (#4409) --- src/include/migraphx/op/pack_fp4.hpp | 37 ++++---- src/include/migraphx/op/unpack_fp4.hpp | 31 +++---- src/onnx/parse_mxfixneuron.cpp | 6 +- src/onnx/parse_quantizelinear.cpp | 8 +- src/simplify_algebra.cpp | 3 +- src/targets/gpu/jit/pack_fp4.cpp | 6 +- src/targets/gpu/jit/unpack_fp4.cpp | 6 +- src/targets/gpu/mlir.cpp | 87 ++++++++++++++----- test/onnx/parse/mxfixneuron_test.cpp | 8 +- .../parse/quantizelinear_mx_type_test.cpp | 8 +- test/ref/pack_unpack_fp4.cpp | 4 +- test/simplify_qdq_test.cpp | 52 +++++------ test/verify/test_mxfp4_gemm.cpp | 8 +- test/verify/test_pack_fp4.cpp | 7 +- test/verify/test_unpack_fp4.cpp | 8 +- 15 files changed, 151 insertions(+), 128 deletions(-) diff --git a/src/include/migraphx/op/pack_fp4.hpp b/src/include/migraphx/op/pack_fp4.hpp index ade6bd85d37..3a203e30877 100644 --- a/src/include/migraphx/op/pack_fp4.hpp +++ b/src/include/migraphx/op/pack_fp4.hpp @@ -34,42 +34,35 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { namespace op { +/** + * Packs fastest dimension of tensor into fp4x2_type such that the + * output dimensions are [x_0, ..., x_pack/2, ...] 
+ */ struct pack_fp4 { - int64_t axis = -1; - std::string name() const { return "pack_fp4"; } - value attributes() const - { - value normalize = value::object{}; - normalize["axis"] = value::array{normalize_attribute::include_min}; - return {{"normalize_axes", normalize}}; - } - - template - static auto reflect(Self& self, F f) - { - return pack(f(self.axis, "axis")); - } - migraphx::shape normalize_compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.same_dims().has(1); const auto& in_shape = inputs.front(); - auto new_lens = in_shape.lens(); - if(new_lens[axis] % 2 != 0) + int fast_axis = std::min_element(in_shape.strides().cbegin(), in_shape.strides().cend()) - + in_shape.strides().cbegin(); + auto new_lens = in_shape.lens(); + if(new_lens.at(fast_axis) % 2 != 0) { - MIGRAPHX_THROW("PACK_FP4: Can not pack axis that has odd lengths"); + MIGRAPHX_THROW("PACK_FP4: Fast dimension is odd, cannot pack"); } - new_lens[axis] /= 2; - return {migraphx::shape::fp4x2_type, new_lens}; + new_lens[fast_axis] /= 2; + return in_shape.with_lens(migraphx::shape::fp4x2_type, new_lens); } argument compute(const shape& output_shape, const std::vector& args) const { const auto& input = args.front(); auto in_shape = input.get_shape(); + int fast_axis = std::min_element(in_shape.strides().cbegin(), in_shape.strides().cend()) - + in_shape.strides().cbegin(); argument result{output_shape}; auto out = result.get(); @@ -78,9 +71,9 @@ struct pack_fp4 using inp_type = typename decltype(inp)::value_type; auto data_idx = output_shape.multi(i); auto in_data_multi_idx = data_idx; - in_data_multi_idx[axis] *= 2; + in_data_multi_idx[fast_axis] *= 2; inp_type inp_val0 = inp[in_data_multi_idx]; - in_data_multi_idx[axis] += 1; + in_data_multi_idx[fast_axis] += 1; inp_type inp_val1 = inp[in_data_multi_idx]; uint8_t out_val0 = cast_to_fp4(inp_val0); uint8_t out_val1 = cast_to_fp4(inp_val1); diff --git a/src/include/migraphx/op/unpack_fp4.hpp b/src/include/migraphx/op/unpack_fp4.hpp 
index 3af4c69a094..c00292f0075 100644 --- a/src/include/migraphx/op/unpack_fp4.hpp +++ b/src/include/migraphx/op/unpack_fp4.hpp @@ -35,26 +35,15 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +/** + * Unpacks fastest dimension of tensor into fp8e4m3fn_type such that the + * output dimensions are [x_0, ..., 2 * x_pack, ...] + */ namespace op { struct unpack_fp4 { - int64_t axis = -1; - std::string name() const { return "unpack_fp4"; } - value attributes() const - { - value normalize = value::object{}; - normalize["axis"] = value::array{normalize_attribute::include_min}; - return {{"normalize_axes", normalize}}; - } - - template - static auto reflect(Self& self, F f) - { - return pack(f(self.axis, "axis")); - } - migraphx::shape normalize_compute_shape(std::vector inputs) const { check_shapes{inputs, *this}.same_dims().has(1); @@ -64,14 +53,18 @@ struct unpack_fp4 MIGRAPHX_THROW("UNPACK_FP4: Only fp4x2_type is supported for unpacking"); } auto new_lens = in_shape.lens(); - new_lens[axis] *= 2; - return {migraphx::shape::fp8e4m3fn_type, new_lens}; + int fast_axis = std::min_element(in_shape.strides().cbegin(), in_shape.strides().cend()) - + in_shape.strides().cbegin(); + new_lens[fast_axis] *= 2; + return in_shape.with_lens(migraphx::shape::fp8e4m3fn_type, new_lens); } argument compute(const shape& output_shape, const std::vector& args) const { const auto& input = args.front(); auto in_shape = input.get_shape(); + int fast_axis = std::min_element(in_shape.strides().cbegin(), in_shape.strides().cend()) - + in_shape.strides().cbegin(); migraphx::shape fp8_shape = shape{migraphx::shape::fp8e4m3fn_type, output_shape.lens()}; argument fp8_arg{fp8_shape}; @@ -79,13 +72,13 @@ struct unpack_fp4 fp8_arg.visit([&](auto out) { par_for(in_shape.elements(), [&](auto i) { auto data_idx = in_shape.multi(i); - data_idx[axis] *= 2; + data_idx[fast_axis] *= 2; // unpacking 2 unsigned parts // unpacking 4 least significant bits first uint8_t fp4_val = inp[i]; 
out[data_idx] = fp4_to_fp8(fp4_val); - data_idx[axis] += 1; + data_idx[fast_axis] += 1; fp4_val = fp4_val >> 4u; out[data_idx] = fp4_to_fp8(fp4_val); }); diff --git a/src/onnx/parse_mxfixneuron.cpp b/src/onnx/parse_mxfixneuron.cpp index af4c129484b..211e8f8de0f 100644 --- a/src/onnx/parse_mxfixneuron.cpp +++ b/src/onnx/parse_mxfixneuron.cpp @@ -34,7 +34,7 @@ namespace onnx { struct parse_mxfixneuron : op_parser { - std::vector operators() const { return {{"MXFixNeuron"}}; } + std::vector operators() const { return {{"MXFixNeuron"}, {"MXQuantizeDequantize"}}; } instruction_ref parse(const op_desc& /*opd*/, const onnx_parser& /*parser*/, @@ -147,9 +147,9 @@ struct parse_mxfixneuron : op_parser padding.at(fast_axis * 2 + 1) = 1; q_ins = info.add_instruction(make_op("pad", {{"pads", padding}}), q_ins); } - auto pack_ins = info.add_instruction(make_op("pack_fp4", {{"axis", fast_axis}}), + auto pack_ins = info.add_instruction(make_op("pack_fp4"), q_ins); // output is fp4x2_type - auto unpack_ins = info.add_instruction(make_op("unpack_fp4", {{"axis", fast_axis}}), + auto unpack_ins = info.add_instruction(make_op("unpack_fp4"), pack_ins); // output is fp8e4m3fn_type if(odd_fast_axis) { diff --git a/src/onnx/parse_quantizelinear.cpp b/src/onnx/parse_quantizelinear.cpp index 03a6395308f..91ae4de6692 100644 --- a/src/onnx/parse_quantizelinear.cpp +++ b/src/onnx/parse_quantizelinear.cpp @@ -114,10 +114,10 @@ struct parse_quantizelinear : op_parser padding.at(fast_axis * 2 + 1) = 1; q_ins = info.add_instruction(make_op("pad", {{"pads", padding}}), q_ins); } - auto pack_ins = info.add_instruction(make_op("pack_fp4", {{"axis", fast_axis}}), - q_ins); // output is fp4x2_type - auto unpack_ins = info.add_instruction(make_op("unpack_fp4", {{"axis", fast_axis}}), - pack_ins); // output is fp8e4m3fn_type + // output is fp4x2_type + auto pack_ins = info.add_instruction(make_op("pack_fp4"), q_ins); + // output is fp8e4m3fn_type + auto unpack_ins = 
info.add_instruction(make_op("unpack_fp4"), pack_ins); if(odd_fast_axis) { // slice off padded values diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 3af49319781..5199b850f4f 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -1633,10 +1633,11 @@ struct find_add_convs MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) { + // checking size to prevent matching block quantized quant_dot for now auto pred = [&](auto name) { return [=](auto i) { return i->name() == name and i->inputs().front() == ins and - i->inputs().at(1)->can_eval(); + i->inputs().at(1)->can_eval() and i->inputs().size() == 2; }; }; auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")); diff --git a/src/targets/gpu/jit/pack_fp4.cpp b/src/targets/gpu/jit/pack_fp4.cpp index 1f0b4450762..0cac85980a1 100644 --- a/src/targets/gpu/jit/pack_fp4.cpp +++ b/src/targets/gpu/jit/pack_fp4.cpp @@ -68,12 +68,16 @@ struct pack_fp4_compiler : compiler options.kernel_name = "pack_fp4_kernel"; options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements())); + const auto& in_shape = inputs.front(); + int fast_axis = std::min_element(in_shape.strides().cbegin(), in_shape.strides().cend()) - + in_shape.strides().cbegin(); + auto src = interpolate_string(pack_fp4_kernel, {{"kernel", options.kernel_name}, {"params", enum_params(options.inputs.size(), "void * private_p")}, {"args", enum_params(options.inputs.size(), "private_p")}, - {"axis", std::to_string(v.at("axis").to())}}); + {"axis", std::to_string(fast_axis)}}); return compile_hip_code_object(ctx, src, options); } diff --git a/src/targets/gpu/jit/unpack_fp4.cpp b/src/targets/gpu/jit/unpack_fp4.cpp index 9720b939a02..b1c2f8a2d55 100644 --- a/src/targets/gpu/jit/unpack_fp4.cpp +++ b/src/targets/gpu/jit/unpack_fp4.cpp @@ -68,12 +68,16 @@ struct unpack_fp4_compiler : compiler options.kernel_name = "unpack_fp4_kernel"; options.set_launch_params(v, compute_global_for(ctx, 
inputs.front().elements())); + const auto& in_shape = inputs.front(); + int fast_axis = std::min_element(in_shape.strides().cbegin(), in_shape.strides().cend()) - + in_shape.strides().cbegin(); + auto src = interpolate_string(unpack_fp4_kernel, {{"kernel", options.kernel_name}, {"params", enum_params(options.inputs.size(), "void * private_p")}, {"args", enum_params(options.inputs.size(), "private_p")}, - {"axis", std::to_string(v.at("axis").to())}}); + {"axis", std::to_string(fast_axis)}}); return compile_hip_code_object(ctx, src, options); } diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index 6b15f193026..9ca5c35330c 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -317,7 +317,7 @@ struct mlir_program // non-computable type is not visit-able if(t == shape::fp4x2_type) { - return mlirFloat8E4M3FNTypeGet(ctx.get()); + return mlirFloat4E2M1FNTypeGet(ctx.get()); } MlirType result; shape::visit(t, [&](auto as) { @@ -368,8 +368,8 @@ struct mlir_program std::vector make_mlir_shapeds(const Range& r) { std::vector result; - std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) { - return make_mlir_shaped(s); + std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& i) { + return make_mlir_shaped(i); }); return result; } @@ -618,22 +618,23 @@ struct mlir_program { auto names = m.get_parameter_names(); std::sort(names.begin(), names.end()); - std::vector inputs; - std::transform(names.begin(), - names.end(), - std::back_inserter(inputs), - [&](const std::string& name) { return m.get_parameter_shape(name); }); + std::vector input_shapes; + std::transform( + names.begin(), + names.end(), + std::back_inserter(input_shapes), + [&](const std::string& name) { return get_shape_for_mlir(m.get_parameter(name)); }); std::vector outputs = m.get_output_shapes(); - std::vector arg_locs(inputs.size(), location); - auto body_inputs = make_mlir_shapeds(inputs); + std::vector 
arg_locs(input_shapes.size(), location); + auto body_inputs = make_mlir_shapeds(input_shapes); mlir_region region = mlirRegionCreate(); mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data()); MlirBlock result = fbody.get(); mlirRegionAppendOwnedBlock(region.get(), fbody.release()); auto ops = create_operation_state("func.func"); - ops.add_attributes({{"function_type", make_function_type(inputs, outputs)}, + ops.add_attributes({{"function_type", make_function_type(input_shapes, outputs)}, {"sym_name", sym_name}, {"kernel", std::string("mixr")}, {"arch", target_arch}, @@ -663,8 +664,6 @@ struct mlir_program return "migraphx.literal"; if(ins->name() == "unpack_int4") return "migraphx.unpack"; - if(ins->name() == "unpack_fp4") - return "migraphx.unpack"; if(ins->name() == "convolution_backwards") return "migraphx.backwards_data_convolution"; if(is_reshape(ins->name())) @@ -698,14 +697,54 @@ struct mlir_program return v; } - static shape get_shape(instruction_ref ins) + static bool input_is_unpack_fp4(instruction_ref ins) + { + ins = instruction::get_output_alias(ins); + if(ins->name() == "reshape") + { + return input_is_unpack_fp4(ins->inputs().front()); + } + if(ins->name() == "unpack_fp4") + { + return true; + } + return false; + } + + static shape make_fp4_unpacked_shape(shape s) { + auto new_lens = s.lens(); + auto new_strides = s.strides(); + int fast_axis = + std::min_element(s.strides().cbegin(), s.strides().cend()) - s.strides().cbegin(); + new_lens[fast_axis] *= 2; + for(auto i : range(new_strides.size())) + { + if(i != fast_axis) + { + new_strides.at(i) *= 2; + } + } + return {shape::fp4x2_type, new_lens, new_strides}; + } + + static shape get_shape_for_mlir(instruction_ref ins) + { + shape ret = ins->get_shape(); if(ins->name() == "@return") { assert(ins->inputs().size() == 1); - return ins->inputs().front()->get_shape(); + ret = ins->inputs().front()->get_shape(); + } + else if(input_is_unpack_fp4(ins)) + { + ret = 
ret.with_type(shape::fp4x2_type); + } + else if(ins->get_shape().type() == shape::fp4x2_type) + { + ret = make_fp4_unpacked_shape(ret); } - return ins->get_shape(); + return ret; } static std::string get_symbol_name(const module& m) @@ -734,7 +773,7 @@ struct mlir_program { if(ins->name() == "@param") continue; - if(ins->name() == "contiguous") + if(contains({"contiguous", "unpack_fp4"}, ins->name())) { ins_map[ins] = ins_map[ins->inputs().at(0)]; continue; @@ -742,15 +781,16 @@ struct mlir_program auto name = get_name(ins); auto ops = create_operation_state(name); ops.add_attribute_value(get_operator_value(ins)); + + // handles single output if(ins->name() != "@return") - ops.add_results({get_shape(ins)}); + ops.add_results({get_shape_for_mlir(ins)}); if(ins->name() == "@literal") { literal r = ins->get_literal(); - auto sh = ins->get_shape(); - MlirType shaped_type = make_mlir_shaped(sh); + MlirType shaped_type = make_mlir_shaped(ins->get_shape()); MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type); MlirAttribute mlir_value_attr = mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data()); @@ -1300,8 +1340,13 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx, mlir_program mp; mp.set_gpu_properties(migraphx_ctx); mp.parse(m); - auto tc = mp.get_tuning_config(exhaustive); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); + if(trace) + { + auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); + std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; + } + auto tc = mp.get_tuning_config(exhaustive); static std::mutex mutex; if(trace) { diff --git a/test/onnx/parse/mxfixneuron_test.cpp b/test/onnx/parse/mxfixneuron_test.cpp index fcbbe2b3e19..e516a325e60 100644 --- a/test/onnx/parse/mxfixneuron_test.cpp +++ b/test/onnx/parse/mxfixneuron_test.cpp @@ -56,8 +56,8 @@ TEST_CASE(mxfixneuron_even_test) input, block_scales_ins); auto quantized_shape = q_ins->get_shape(); - auto pack_ins = 
mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 3}}), q_ins); - auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), pack_ins); + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4"), q_ins); + auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4"), pack_ins); mm->add_instruction(migraphx::make_op("dequantizelinear"), unpack_ins, block_scales_ins); auto prog = optimize_onnx("mxfixneuron_even_test.onnx"); @@ -103,8 +103,8 @@ TEST_CASE(mxfixneuron_odd_test) auto quantized_shape = q_ins->get_shape(); auto pad_ins = mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 1}}}), q_ins); - auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 2}}), pad_ins); - auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 2}}), pack_ins); + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4"), pad_ins); + auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4"), pack_ins); auto slice_ins = mm->add_instruction( migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {5}}}), unpack_ins); mm->add_instruction(migraphx::make_op("dequantizelinear"), slice_ins, block_scales_ins); diff --git a/test/onnx/parse/quantizelinear_mx_type_test.cpp b/test/onnx/parse/quantizelinear_mx_type_test.cpp index 4beb1395533..ae6b3a7ee39 100644 --- a/test/onnx/parse/quantizelinear_mx_type_test.cpp +++ b/test/onnx/parse/quantizelinear_mx_type_test.cpp @@ -41,8 +41,8 @@ TEST_CASE(quantizelinear_mxfp4_even_test) migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), l0, l1_reshape); - auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 3}}), q_ins); - auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), pack_ins); + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4"), q_ins); + auto unpack_ins = 
mm->add_instruction(migraphx::make_op("unpack_fp4"), pack_ins); mm->add_return({unpack_ins}); auto prog = read_onnx("quantizelinear_mxfp4_even_test.onnx"); @@ -67,8 +67,8 @@ TEST_CASE(quantizelinear_mxfp4_odd_test) l1_reshape); auto pad_ins = mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 0, 0, 1}}}), q_ins); - auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 3}}), pad_ins); - auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), pack_ins); + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4"), pad_ins); + auto unpack_ins = mm->add_instruction(migraphx::make_op("unpack_fp4"), pack_ins); auto slice_ins = mm->add_instruction( migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {7}}}), unpack_ins); mm->add_return({slice_ins}); diff --git a/test/ref/pack_unpack_fp4.cpp b/test/ref/pack_unpack_fp4.cpp index 7ea82bbbac8..128c9fb76fc 100644 --- a/test/ref/pack_unpack_fp4.cpp +++ b/test/ref/pack_unpack_fp4.cpp @@ -75,8 +75,8 @@ TEST_CASE(pack_unpack_fp4) auto* mm = p.get_main_module(); migraphx::shape s{migraphx::shape::float_type, {2, 2}}; auto l0 = mm->add_literal(migraphx::literal{s, {-2.f, 3.4f, 3.5f, 0.f}}); - auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 0}}), l0); - mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 0}}), pack_ins); + auto pack_ins = mm->add_instruction(migraphx::make_op("pack_fp4"), l0); + mm->add_instruction(migraphx::make_op("unpack_fp4"), pack_ins); p.compile(migraphx::make_target("ref")); auto result = p.eval({}).back(); std::vector results_vector(4); diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index 81f6b99ac69..e123499fde7 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -1657,9 +1657,9 @@ TEST_CASE(pointwise_concat_quant_per_channel) // auto scale_weights = m1.add_parameter("scale_weights", shape_scale_weights); // // auto unpack_input = 
-// m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_input); +// m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_input); // auto unpack_weights = -// m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_weights); +// m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_weights); // auto dq_input = // m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_input, scale_input); // auto dq_weights = m1.add_instruction( @@ -1683,9 +1683,9 @@ TEST_CASE(pointwise_concat_quant_per_channel) // auto scale_weights = m2.add_parameter("scale_weights", shape_scale_weights); // // auto unpack_input = -// m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_input); +// m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_input); // auto unpack_weights = -// m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_weights); +// m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_weights); // auto quant_conv = m2.add_instruction(migraphx::make_op("quant_convolution", // {{"padding", {0, 0, 0, 0}}, // {"stride", {1, 1}}, @@ -1719,9 +1719,9 @@ TEST_CASE(pointwise_concat_quant_per_channel) // auto scale_weights = m1.add_parameter("scale_weights", shape_scale_weights); // // auto unpack_input = -// m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_input); +// m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_input); // auto unpack_weights = -// m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_weights); +// m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_weights); // auto slice_input = m1.add_instruction( // migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {21}}}), // unpack_input); @@ -1748,9 +1748,9 @@ TEST_CASE(pointwise_concat_quant_per_channel) // auto scale_weights = m2.add_parameter("scale_weights", shape_scale_weights); // // auto unpack_input = -// 
m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_input); +// m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_input); // auto unpack_weights = -// m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_weights); +// m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_weights); // auto slice_input = m2.add_instruction( // migraphx::make_op("slice", {{"axes", {3}}, {"starts", {0}}, {"ends", {21}}}), // unpack_input); @@ -1787,10 +1787,8 @@ TEST_CASE(fp4x2_quant_dot_even) auto scale_a = m1.add_parameter("scale_a", shape_scales_a); auto scale_b = m1.add_parameter("scale_b", shape_scales_b); - auto unpack_a = - m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); - auto unpack_b = - m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto unpack_a = m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_a); + auto unpack_b = m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_b); auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, dq_b); @@ -1804,10 +1802,8 @@ TEST_CASE(fp4x2_quant_dot_even) auto scale_a = m2.add_parameter("scale_a", shape_scales_a); auto scale_b = m2.add_parameter("scale_b", shape_scales_b); - auto unpack_a = - m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); - auto unpack_b = - m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto unpack_a = m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_a); + auto unpack_b = m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_b); auto quant_dot = m2.add_instruction( migraphx::make_op("quant_dot"), unpack_a, unpack_b, scale_a, scale_b); m2.add_return({quant_dot}); @@ -1831,10 +1827,8 @@ TEST_CASE(fp4x2_quant_dot_trans_b) auto 
scale_a = m1.add_parameter("scale_a", shape_scales_a); auto scale_b = m1.add_parameter("scale_b", shape_scales_b); - auto unpack_a = - m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); - auto unpack_b = - m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto unpack_a = m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_a); + auto unpack_b = m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_b); auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); auto dq_b = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); auto trans_b = m1.add_instruction( @@ -1850,10 +1844,8 @@ TEST_CASE(fp4x2_quant_dot_trans_b) auto scale_a = m2.add_parameter("scale_a", shape_scales_a); auto scale_b = m2.add_parameter("scale_b", shape_scales_b); - auto unpack_a = - m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); - auto unpack_b = - m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto unpack_a = m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_a); + auto unpack_b = m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_b); auto trans_b = m2.add_instruction( migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), unpack_b); auto trans_scale_b = m2.add_instruction( @@ -1885,10 +1877,8 @@ TEST_CASE(fp4x2_quant_dot_const_b) auto scale_a = m1.add_parameter("scale_a", shape_scales_a); auto scale_b = m1.add_literal(scale_b_lit); - auto unpack_a = - m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); - auto unpack_b = - m1.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto unpack_a = m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_a); + auto unpack_b = m1.add_instruction(migraphx::make_op("unpack_fp4"), packed_b); auto dq_a = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_a, scale_a); auto dq_b = 
m1.add_instruction(migraphx::make_op("dequantizelinear"), unpack_b, scale_b); auto dot = m1.add_instruction(migraphx::make_op("dot"), dq_a, dq_b); @@ -1902,10 +1892,8 @@ TEST_CASE(fp4x2_quant_dot_const_b) auto scale_a = m2.add_parameter("scale_a", shape_scales_a); auto scale_b = m2.add_literal(scale_b_lit); - auto unpack_a = - m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_a); - auto unpack_b = - m2.add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 3}}), packed_b); + auto unpack_a = m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_a); + auto unpack_b = m2.add_instruction(migraphx::make_op("unpack_fp4"), packed_b); auto quant_dot = m2.add_instruction( migraphx::make_op("quant_dot"), unpack_a, unpack_b, scale_a, scale_b); m2.add_return({quant_dot}); diff --git a/test/verify/test_mxfp4_gemm.cpp b/test/verify/test_mxfp4_gemm.cpp index 152a6fc3c70..9cf151ceb09 100644 --- a/test/verify/test_mxfp4_gemm.cpp +++ b/test/verify/test_mxfp4_gemm.cpp @@ -119,8 +119,8 @@ struct test_mxfp4_gemm : verify_program migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), input, input_scales); - input = mmain->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 1}}), input); - input = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), input); + input = mmain->add_instruction(migraphx::make_op("pack_fp4"), input); + input = mmain->add_instruction(migraphx::make_op("unpack_fp4"), input); input = mmain->add_instruction(migraphx::make_op("dequantizelinear"), input, input_scales); auto weights = mmain->add_literal(migraphx::generate_literal( migraphx::shape{migraphx::shape::float_type, {1000, 2048}}, 2)); @@ -129,8 +129,8 @@ struct test_mxfp4_gemm : verify_program migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), weights, weight_scales); - weights = mmain->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 1}}), weights); - weights = 
mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), weights); + weights = mmain->add_instruction(migraphx::make_op("pack_fp4"), weights); + weights = mmain->add_instruction(migraphx::make_op("unpack_fp4"), weights); weights = mmain->add_instruction(migraphx::make_op("dequantizelinear"), weights, weight_scales); weights = mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), diff --git a/test/verify/test_pack_fp4.cpp b/test/verify/test_pack_fp4.cpp index 178f5c12e60..fb176e9d6c9 100644 --- a/test/verify/test_pack_fp4.cpp +++ b/test/verify/test_pack_fp4.cpp @@ -27,8 +27,8 @@ #include #include -template -struct test_pack_fp4 : verify_program> +template +struct test_pack_fp4 : verify_program> { migraphx::program create_program() const { @@ -36,10 +36,9 @@ struct test_pack_fp4 : verify_program> auto* mm = p.get_main_module(); auto x = mm->add_parameter("x", migraphx::shape{T, {64, 32}}); - mm->add_instruction(migraphx::make_op("pack_fp4", {{"axis", Axis}}), x); + mm->add_instruction(migraphx::make_op("pack_fp4"), x); return p; } }; template struct test_pack_fp4; -template struct test_pack_fp4; diff --git a/test/verify/test_unpack_fp4.cpp b/test/verify/test_unpack_fp4.cpp index a70b03f412d..50229714e57 100644 --- a/test/verify/test_unpack_fp4.cpp +++ b/test/verify/test_unpack_fp4.cpp @@ -27,8 +27,7 @@ #include #include -template -struct test_unpack_fp4 : verify_program> +struct test_unpack_fp4 : verify_program { migraphx::program create_program() const { @@ -36,11 +35,8 @@ struct test_unpack_fp4 : verify_program> auto* mm = p.get_main_module(); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp4x2_type, {32, 16}}); - mm->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", Axis}}), x); + mm->add_instruction(migraphx::make_op("unpack_fp4"), x); return p; } }; - -template struct test_unpack_fp4<>; -template struct test_unpack_fp4<0>;