Open

Commits (52 total; changes shown from 38)
4afcabd
Fix reshapes propagation for simplify_qdq and constant propagation
CharlieL7 Sep 25, 2025
9cffdf0
Fix more bugs and make tests
CharlieL7 Sep 25, 2025
4d57e7f
Fix reinterpret_cast of the possible r-value reference pointer to con…
CharlieL7 Sep 25, 2025
eb4d1c4
Merge branch 'develop' into fix_mxfp4_bugs
CharlieL7 Sep 25, 2025
3fa8a1b
Fix introduced bug
CharlieL7 Sep 26, 2025
ba1d4bc
Merge branch 'fix_mxfp4_bugs' of github.com:ROCmSoftwarePlatform/AMDM…
CharlieL7 Sep 26, 2025
7b11af3
Tidy fix
CharlieL7 Sep 26, 2025
335439b
initial
CharlieL7 Sep 26, 2025
787f539
Merge branch 'develop' into fix_mxfp4_bugs
causten Sep 27, 2025
2c2e3e7
Merge branch 'develop' into fix_mxfp4_bugs
causten Sep 28, 2025
5e04de7
reviews code style
CharlieL7 Sep 29, 2025
a29f4d9
Update rocMLIR and rocm
CharlieL7 Sep 29, 2025
473c85b
Avoid visit of fp4x2_type
CharlieL7 Sep 29, 2025
866dbf7
typo fix
CharlieL7 Sep 29, 2025
01e28de
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mlir_mxfp4…
CharlieL7 Sep 29, 2025
df2b72a
Merge branch 'fix_mxfp4_bugs' of github.com:ROCm/AMDMIGraphX into mli…
CharlieL7 Sep 29, 2025
c42d896
Remove quant_conv again
CharlieL7 Sep 29, 2025
b66be3b
Merge branch 'fix_mxfp4_bugs' of github.com:ROCm/AMDMIGraphX into mli…
CharlieL7 Sep 29, 2025
5eacd87
Update requirements rocmlir to latest branch
CharlieL7 Sep 29, 2025
2ecd0e3
Add Umang's changes
CharlieL7 Sep 30, 2025
8dd905b
Merge branch 'develop' into mlir_mxfp4_test
CharlieL7 Oct 7, 2025
a91f16b
In Progress
CharlieL7 Oct 10, 2025
25cf114
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mlir_mxfp4…
CharlieL7 Oct 13, 2025
8306abb
Merge branch 'mlir_mxfp4_test' of github.com:ROCm/AMDMIGraphX into ml…
CharlieL7 Oct 13, 2025
0f26532
tidy up
CharlieL7 Oct 14, 2025
d05b8f8
Enable simplify_qdq for unpack_fp4 only if >=MI350
CharlieL7 Oct 14, 2025
1299fd4
Add mxfp4 quant_dot verify test
CharlieL7 Oct 14, 2025
77bfb9b
Fix typos
CharlieL7 Oct 14, 2025
1eb70a5
Add to header
CharlieL7 Oct 14, 2025
57bd0eb
Use -> and using
CharlieL7 Oct 14, 2025
46c938c
More fixes
CharlieL7 Oct 14, 2025
7b5cd46
More fixes
CharlieL7 Oct 14, 2025
6dcb1b1
Merge branch 'mlir_mxfp4_test' of github.com:ROCm/AMDMIGraphX into ml…
CharlieL7 Oct 14, 2025
8334922
etc
CharlieL7 Oct 14, 2025
d58baef
add return
CharlieL7 Oct 14, 2025
ca0ef2d
Fix test
CharlieL7 Oct 14, 2025
81d4bd8
Update verify test tolerance
CharlieL7 Oct 14, 2025
ccf2f7a
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mlir_mxfp4…
CharlieL7 Oct 14, 2025
3b7cc58
Fix tests and change flag behavior
CharlieL7 Oct 15, 2025
8e97a82
Typo fix
CharlieL7 Oct 15, 2025
f4c30b9
Typo and include fixes
CharlieL7 Oct 15, 2025
6940e9f
formatting
CharlieL7 Oct 15, 2025
387114b
Fix mlir compilation for operand_size
CharlieL7 Oct 16, 2025
95e2ab0
Merge branch 'mlir_mxfp4_test' of github.com:ROCm/AMDMIGraphX into ml…
CharlieL7 Oct 16, 2025
991ca0d
Update changelog
CharlieL7 Oct 16, 2025
92aea67
Add dead_code_elim dependency
CharlieL7 Oct 17, 2025
4b2c6c0
Add back tolerance for verify test
CharlieL7 Oct 21, 2025
f6ea7e0
AIMIGRAPHX-193 use vector sizes
CharlieL7 Oct 29, 2025
a080654
Update mlir tests
CharlieL7 Oct 31, 2025
ad4a79c
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mlir_mxfp4…
CharlieL7 Oct 31, 2025
4bdd675
Remove fp4 check on quant_dot rocmlir compile
CharlieL7 Oct 31, 2025
6c56366
Formatting
CharlieL7 Oct 31, 2025
14 changes: 12 additions & 2 deletions src/driver/verify.cpp
@@ -43,7 +43,8 @@ inline namespace MIGRAPHX_INLINE_NS {

/**
* Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
* Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction is found in the
* Sets to fp4 tolerances if any fp4x2_type is found.
* Else sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction is found in the
* model.
*/
verify::tolerance get_tolerances(const program& p,
@@ -58,8 +59,17 @@ verify::tolerance get_tolerances(const program& p,
ins.get_shape().type() == shape::bf16_type);
});
});
bool has_fp4 = any_of(p.get_modules(), [](auto&& m) {
return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::fp4x2_type); });
});
migraphx::verify::tolerance result{};
if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16)
if(has_fp4)
{
result.rms_tol = 8e-1;
result.atol = 4e-1;
result.rtol = 4e-1;
}
else if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16)
{
result.rms_tol = 8e-2;
result.atol = 4e-2;
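A note on why the fp4 numbers are an order of magnitude looser than the fp16 ones (8e-1/4e-1 vs 8e-2/4e-2): fp4e2m1 carries a single mantissa bit, so per-element rounding error is already large before any accumulation. A quick worked check, my arithmetic rather than anything stated in this diff: the fp4e2m1 values bracketing 5.0 are 4.0 and 6.0, so

    relative error at x = 5.0  =  |5.0 - 4.0| / 5.0  =  0.2

which sits right at the scale of the 4e-1 rtol chosen here.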
2 changes: 2 additions & 0 deletions src/include/migraphx/simplify_qdq.hpp
@@ -38,6 +38,8 @@ struct module;
*/
struct MIGRAPHX_EXPORT simplify_qdq
{
bool remove_qdq_only = false;
bool use_mx_quant = false;
std::string name() const { return "simplify_qdq"; }
void apply(module& m) const;
};
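A hedged sketch of how the two new knobs are meant to be combined, mirroring the call sites later in this diff (designated initializers as used there; the defaults run the full pipeline):

    // GPU target: enable MX (fp4) matching only when the hardware has MX intrinsics
    simplify_qdq{.use_mx_quant = gfx_has_mx_intrinsics()};
    // Reference target: only cancel quantize/dequantize pairs, keep everything else
    simplify_qdq{.remove_qdq_only = true};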
36 changes: 23 additions & 13 deletions src/simplify_qdq.cpp
@@ -618,19 +618,29 @@ void add_int4_pack_unpack_pair(module& m)

void simplify_qdq::apply(module& m) const
{
// first step: add pack/unpack pair between qdq for int4 weights
add_int4_pack_unpack_pair(m);
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_find_mx_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, remove_qdq_pairs{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_qlinear_reused{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_concat_qlinear{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_zero_point(m);
if(remove_qdq_only)
{
match::find_matches(m, remove_qdq_pairs{});
}
else
{
// first step: add pack/unpack pair between qdq for int4 weights
add_int4_pack_unpack_pair(m);
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
if(use_mx_quant)
{
match::find_matches(m, match_find_mx_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
}
match::find_matches(m, remove_qdq_pairs{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_qlinear_reused{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_concat_qlinear{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_zero_point(m);
}
}

} // namespace MIGRAPHX_INLINE_NS
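For intuition, the remove_qdq_only branch runs just the pair cancellation; roughly (informal rewrite, per the remove_qdq_pairs matcher rather than anything new in this hunk):

    dequantizelinear(quantizelinear(x, scale), scale)  ==>  x

so a target taking this branch evaluates the model in the original precision, while the else branch additionally performs int4 pack/unpack insertion, quantizable-op fusion (MX-aware only when use_mx_quant is set), and zero-point removal.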
6 changes: 6 additions & 0 deletions src/targets/gpu/device_name.cpp
@@ -75,6 +75,12 @@ bool gfx_has_bf16_intrinsics()
return not(starts_with(device_name, "gfx1030"));
}

bool gfx_has_mx_intrinsics()
{
const auto device_name = trim(split_string(get_device_name(), ':').front());
return starts_with(device_name, "gfx9") and device_name >= "gfx950";
}

#if MIGRAPHX_USE_HIPBLASLT
// Archs that support hipBLASLt but are defaulted to use rocBLAS.
bool gfx_default_rocblas()
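A quick sanity sketch of the new predicate under the lexicographic comparison above, my reading of the gating (the PR enables the MX path only for MI350-class parts):

    gfx950   ->  true    (starts with "gfx9" and compares >= "gfx950")
    gfx942   ->  false   (MI300-class, compares below "gfx950")
    gfx1030  ->  false   (fails the "gfx9" prefix check)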
3 changes: 3 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/device_name.hpp
@@ -43,8 +43,11 @@ MIGRAPHX_GPU_EXPORT bool gfx_has_fp8ocp_intrinsics();

MIGRAPHX_GPU_EXPORT bool gfx_has_bf16_intrinsics();

MIGRAPHX_GPU_EXPORT bool gfx_has_mx_intrinsics();

MIGRAPHX_GPU_EXPORT bool gfx_has_fp8fnuz_support();


#if MIGRAPHX_USE_HIPBLASLT
MIGRAPHX_GPU_EXPORT bool gfx_default_rocblas();
#endif
30 changes: 30 additions & 0 deletions src/targets/gpu/mlir.cpp
@@ -314,6 +314,11 @@ struct mlir_program

MlirType make_type(shape::type_t t) const
{
// fp4x2_type is packed storage that shape::visit cannot visit;
// map it to an 8-bit MLIR float type of the same byte width
if(t == shape::fp4x2_type)
{
return mlirFloat8E4M3FNTypeGet(ctx.get());
}
MlirType result;
shape::visit(t, [&](auto as) {
if(as.type_enum() == shape::float_type)
@@ -503,6 +508,17 @@
{
}

void set_operand_segment_sizes(int num_segments, const std::vector<int>& sizes)
{
MlirAttribute segment_sizes_attr =
mlirDenseI32ArrayGet(prog->ctx.get(), num_segments, sizes.data());
MlirNamedAttribute named_attr = mlirNamedAttributeGet(
mlirIdentifierGet(prog->ctx.get(),
mlirStringRefCreateFromCString("operandSegmentSizes")),
segment_sizes_attr);
mlirOperationStateAddAttributes(&op_state, 1, &named_attr);
}

mlir_operation_state& add_attributes(const std::vector<named_attribute_t>& named_attrs)
{
auto attributes = prog->name_attributes(named_attrs);
@@ -647,6 +663,8 @@ struct mlir_program
return "migraphx.literal";
if(ins->name() == "unpack_int4")
return "migraphx.unpack";
if(ins->name() == "unpack_fp4")
return "migraphx.unpack";
if(ins->name() == "convolution_backwards")
return "migraphx.backwards_data_convolution";
if(is_reshape(ins->name()))
@@ -748,9 +766,20 @@
std::vector<MlirValue> inputs;
transform(
ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); });

if(ins->name() == "quant_dot" and
ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type and
ins->inputs().size() == 4)
{
// Specify operand segment sizes BEFORE creating the operation so MLIR sees it.
// Use the canonical MLIR attribute name 'operandSegmentSizes'.
const std::vector<int> seg_sizes = {1, 1, 1, 1};
ops.set_operand_segment_sizes(4, seg_sizes);
}
ops.add_operands(inputs);

auto outputs = insert(fbody, std::move(ops));

if(ins->name() != "@return")
{
assert(outputs.size() == 1);
@@ -1201,6 +1230,7 @@ mlir_code_object compile_mlir(const context& migraphx_ctx,
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}

auto co = mp.compile(solution);

co.expected_inputs = in_shapes;
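Background on the operandSegmentSizes handling above: MLIR ops declared with the AttrSizedOperandSegments trait require a dense i32 array attribute named operandSegmentSizes that partitions the flat operand list, and it has to be in the operation state before the op is created, which is why the attribute is set ahead of add_operands. A minimal sketch of the call order for the four-operand quant_dot (labeling the operands a, b, a_scale, b_scale is my assumption; the hunk only checks the count and element type):

    const std::vector<int> seg_sizes = {1, 1, 1, 1};  // one operand per segment
    ops.set_operand_segment_sizes(4, seg_sizes);      // attribute first
    ops.add_operands(inputs);                         // then the operands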
2 changes: 1 addition & 1 deletion src/targets/gpu/target.cpp
@@ -189,7 +189,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), fp8_ocp_to_fnuz{}),
enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), dead_code_elimination{}),
simplify_qdq{},
simplify_qdq{.use_mx_quant=gfx_has_mx_intrinsics()},
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},
rewrite_rnn{},
3 changes: 3 additions & 0 deletions src/targets/ref/target.cpp
@@ -35,6 +35,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/eliminate_data_type.hpp>
#include <migraphx/simplify_qdq.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -51,6 +52,8 @@ std::vector<pass> target::get_passes(migraphx::context&, const compile_options&)
dead_code_elimination{},
rewrite_rnn{},
dead_code_elimination{},
simplify_qdq{.remove_qdq_only = true},
dead_code_elimination{},
auto_contiguous{},
dead_code_elimination{},
lowering{},
131 changes: 131 additions & 0 deletions test/verify/test_mxfp4_gemm.cpp
@@ -0,0 +1,131 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>

migraphx::instruction_ref add_dyn_scale_calc(migraphx::module_ref m, migraphx::instruction_ref input, int block_axis, int block_size)

Check warning on line 31 in test/verify/test_mxfp4_gemm.cpp (GitHub Actions / tidy): function 'add_dyn_scale_calc' can be made static or moved into an anonymous namespace to enforce internal linkage [misc-use-internal-linkage,-warnings-as-errors]
{
using migraphx::instruction_ref;
using migraphx::module_ref;
using migraphx::make_op;

// Code similar to that in parse_mxfixneuron
// make reduction axes for calculating block scales
// tmp_lens != input_lens if runt block is padded
instruction_ref tmp_in = input;
const auto input_lens = input->get_shape().lens();
auto tmp_lens = input_lens;
auto block_dim = tmp_lens.at(block_axis);
std::size_t block_padding =
std::ceil(double(block_dim) / double(block_size)) * block_size - block_dim;
// handle runt block by padding
if(block_padding != 0)
{
std::vector<std::size_t> pads_vec(2 * tmp_lens.size(), 0);
pads_vec.at(block_axis + tmp_lens.size()) = block_padding;
tmp_in = m->add_instruction(make_op("pad", {{"pads", pads_vec}}), tmp_in);
tmp_lens = tmp_in->get_shape().lens();
}
// reshape block dimension to {num_blocks, block_size}
std::size_t num_blocks = tmp_lens.at(block_axis) / std::size_t(block_size);
std::vector<std::size_t> reduct_dims = tmp_lens;
reduct_dims.at(block_axis) = block_size;
reduct_dims.insert(reduct_dims.begin() + block_axis, num_blocks);
instruction_ref reshape_ins = m->add_instruction(make_op("reshape", {{"dims", reduct_dims}}), tmp_in);

// dynamic quantization for MX types:
// V_k = fp32 vector input of block size k
// B_k = pow(2, floor(log2(reduce_max(abs(V_k))))) # largest power of 2 less than V
// X_k = block scale k = B_k / (largest power of 2 in fp4e2m1) = B_k / 4
auto abs_ins = m->add_instruction(make_op("abs"), reshape_ins);
auto reduce_max_ins =
m->add_instruction(make_op("reduce_max", {{"axes", {block_axis + 1}}}), abs_ins);
auto log2_ins = m->add_instruction(make_op("log2"), reduce_max_ins);
auto floor_ins = m->add_instruction(make_op("floor"), log2_ins);
auto lit_2_ins = m->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {2.f}});
auto broadcast_lit_2_ins = m->add_instruction(
make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}),
lit_2_ins);
auto pow_ins = m->add_instruction(make_op("pow"), broadcast_lit_2_ins, floor_ins);
auto lit_4_ins = m->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {4.f}});
auto broadcast_lit_4_ins = m->add_instruction(
make_op("multibroadcast", {{"out_lens", reduce_max_ins->get_shape().lens()}}),
lit_4_ins);
auto block_scales_ins = m->add_instruction(make_op("div"), pow_ins, broadcast_lit_4_ins);

// broadcast scales for use in quantizelinear
block_scales_ins = m->add_instruction(
make_op("multibroadcast", {{"out_lens", reduct_dims}}), block_scales_ins);
block_scales_ins =
m->add_instruction(make_op("reshape", {{"dims", tmp_lens}}), block_scales_ins);

// if padded runt block do slicing
if(tmp_lens != input_lens)
{
std::size_t slice_size = input_lens.at(block_axis);
block_scales_ins = m->add_instruction(
make_op("slice", {{"axes", {block_axis}}, {"starts", {0}}, {"ends", {slice_size}}}),
block_scales_ins);
}
return block_scales_ins;
}
Collaborator: I wonder if we should make this an op builder, so it can be reused. Probably outside of the scope of this PR.

Collaborator (author): Yes, probably should be an op builder.

/**
* Designed to be like the final GEMM of resnet50.
*/
struct test_mxfp4_gemm : verify_program<test_mxfp4_gemm>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::module_ref mmain = p.get_main_module();
auto input = mmain->add_parameter("x1", migraphx::shape{migraphx::shape::float_type, {1, 2048}});
auto input_scales = add_dyn_scale_calc(mmain, input, 1, 32);
input = mmain->add_instruction(migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), input, input_scales);
input = mmain->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 1}}), input);
input = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), input);
input = mmain->add_instruction(migraphx::make_op("dequantizelinear"), input, input_scales);
auto weights = mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1000, 2048}}, 2));
auto weight_scales = add_dyn_scale_calc(mmain, weights, 1, 32);
weights = mmain->add_instruction(migraphx::make_op("quantizelinear", {{"out_type", migraphx::shape::float_type}}), weights, weight_scales);
weights = mmain->add_instruction(migraphx::make_op("pack_fp4", {{"axis", 1}}), weights);
weights = mmain->add_instruction(migraphx::make_op("unpack_fp4", {{"axis", 1}}), weights);
weights = mmain->add_instruction(migraphx::make_op("dequantizelinear"), weights, weight_scales);
weights = mmain->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), weights);
auto bias = mmain->add_literal(migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 1000}}, 0));
auto dot = mmain->add_instruction(migraphx::make_op("dot"), input, weights);
auto bias_add = mmain->add_instruction(migraphx::make_op("add"), dot, bias);
mmain->add_return({bias_add});
return p;
}
std::string section() const { return "gemm"; }

std::size_t get_tolerance() const { return 4e5; }
};
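Reading the test together with the passes above: on gfx950-class hardware, simplify_qdq{.use_mx_quant = true} is what should collapse each quantizelinear -> pack_fp4 -> unpack_fp4 -> dequantizelinear chain plus the dot into a single fp4 quant_dot for MLIR to compile (the four-operand case handled in mlir.cpp above); elsewhere the chains remain and the model verifies as plain float. A rough before/after, my reading of match_find_mx_quantizable_ops rather than anything spelled out in this diff, with the operand order assumed:

    dot(dq(unpack_fp4(pack_fp4(q(x, xs))), xs), dq(unpack_fp4(pack_fp4(q(w, ws))), ws))
    ==> quant_dot(pack_fp4(q(x, xs)), pack_fp4(q(w, ws)), xs, ws)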