Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
4afcabd
Fix reshapes propogation for simplify_qdq and constant propagation
CharlieL7 Sep 25, 2025
9cffdf0
Fix more bugs and make tests
CharlieL7 Sep 25, 2025
4d57e7f
Fix reinterpret_cast of the possible r-value reference pointer to con…
CharlieL7 Sep 25, 2025
eb4d1c4
Merge branch 'develop' into fix_mxfp4_bugs
CharlieL7 Sep 25, 2025
3fa8a1b
Fix introduced bug
CharlieL7 Sep 26, 2025
ba1d4bc
Merge branch 'fix_mxfp4_bugs' of github.com:ROCmSoftwarePlatform/AMDM…
CharlieL7 Sep 26, 2025
7b11af3
Tidy fix
CharlieL7 Sep 26, 2025
335439b
initial
CharlieL7 Sep 26, 2025
787f539
Merge branch 'develop' into fix_mxfp4_bugs
causten Sep 27, 2025
2c2e3e7
Merge branch 'develop' into fix_mxfp4_bugs
causten Sep 28, 2025
5e04de7
reviews code style
CharlieL7 Sep 29, 2025
a29f4d9
Update rocMLIR and rocm
CharlieL7 Sep 29, 2025
473c85b
Avoid visit of fp4x2_type
CharlieL7 Sep 29, 2025
866dbf7
typo fix
CharlieL7 Sep 29, 2025
01e28de
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mlir_mxfp4…
CharlieL7 Sep 29, 2025
df2b72a
Merge branch 'fix_mxfp4_bugs' of github.com:ROCm/AMDMIGraphX into mli…
CharlieL7 Sep 29, 2025
c42d896
Remove quant_conv again
CharlieL7 Sep 29, 2025
b66be3b
Merge branch 'fix_mxfp4_bugs' of github.com:ROCm/AMDMIGraphX into mli…
CharlieL7 Sep 29, 2025
5eacd87
Update requirements rocmlir to latest branch
CharlieL7 Sep 29, 2025
2ecd0e3
Add Umang's changes
CharlieL7 Sep 30, 2025
8dd905b
Merge branch 'develop' into mlir_mxfp4_test
CharlieL7 Oct 7, 2025
a91f16b
In Progress
CharlieL7 Oct 10, 2025
25cf114
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mlir_mxfp4…
CharlieL7 Oct 13, 2025
8306abb
Merge branch 'mlir_mxfp4_test' of github.com:ROCm/AMDMIGraphX into ml…
CharlieL7 Oct 13, 2025
0f26532
tidy up
CharlieL7 Oct 14, 2025
d05b8f8
Enable simplify_qdq for unpack_fp4 only if >=MI350
CharlieL7 Oct 14, 2025
1299fd4
Add mxfp4 quant_dot verify test
CharlieL7 Oct 14, 2025
77bfb9b
Fix typos
CharlieL7 Oct 14, 2025
1eb70a5
Add to header
CharlieL7 Oct 14, 2025
57bd0eb
Use -> and using
CharlieL7 Oct 14, 2025
46c938c
More fixes
CharlieL7 Oct 14, 2025
7b5cd46
More fixes
CharlieL7 Oct 14, 2025
6dcb1b1
Merge branch 'mlir_mxfp4_test' of github.com:ROCm/AMDMIGraphX into ml…
CharlieL7 Oct 14, 2025
8334922
etc
CharlieL7 Oct 14, 2025
d58baef
add return
CharlieL7 Oct 14, 2025
ca0ef2d
Fix test
CharlieL7 Oct 14, 2025
81d4bd8
Update verify test tolerance
CharlieL7 Oct 14, 2025
ccf2f7a
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mlir_mxfp4…
CharlieL7 Oct 14, 2025
3b7cc58
Fix tests and change flag behavior
CharlieL7 Oct 15, 2025
8e97a82
Typo fix
CharlieL7 Oct 15, 2025
f4c30b9
Typo and include fixes
CharlieL7 Oct 15, 2025
6940e9f
formatting
CharlieL7 Oct 15, 2025
387114b
Fix mlir compilation for operand_size
CharlieL7 Oct 16, 2025
95e2ab0
Merge branch 'mlir_mxfp4_test' of github.com:ROCm/AMDMIGraphX into ml…
CharlieL7 Oct 16, 2025
991ca0d
Update changelog
CharlieL7 Oct 16, 2025
92aea67
Add dead_code_elim dependency
CharlieL7 Oct 17, 2025
4b2c6c0
Add back tolerance for verify test
CharlieL7 Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
Full documentation for MIGraphX is available at
[https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/](https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/).

## Develop Branch

### Added
* Added MXFP4 support for Quark and Brevitas quantized models (GEMMs only) (#4343)


## MIGraphX 2.14 for ROCm 7.1.0

Expand Down
3 changes: 2 additions & 1 deletion src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,8 @@ struct verify : command<verify>
ap(bisect, {"-b", "--bisect"}, ap.help("Bisect program and verify"), ap.set_value(true));
ap(vo.ref_use_double,
{"--ref-use-double"},
ap.help("Convert floating point values to double on ref"),
ap.help(
"Convert floating point values to double on ref. Also removes Q/DQ pairs on ref."),
ap.set_value(true));
ap(vo.compiled_model, {"--compiled-model", "-c"}, ap.help("Compiled model to use"));
}
Expand Down
19 changes: 16 additions & 3 deletions src/driver/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#include <migraphx/register_target.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/verify_args.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <utility>

namespace migraphx {
Expand All @@ -43,7 +45,8 @@ inline namespace MIGRAPHX_INLINE_NS {

/**
* Gives tolerances based on user input (`rms_tol`, `atol`, `rtol` parameters) and defaults.
* Sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction in found in the
* Sets to fp4 tolerances if any fp4x2_type is found.
* Else sets to fp16 tolerances if `quantize` input is fp16 or any fp16 instruction is found in the
* model.
*/
verify::tolerance get_tolerances(const program& p,
Expand All @@ -58,8 +61,17 @@ verify::tolerance get_tolerances(const program& p,
ins.get_shape().type() == shape::bf16_type);
});
});
bool has_fp4 = any_of(p.get_modules(), [](auto&& m) {
return any_of(*m, [](auto&& ins) { return (ins.get_shape().type() == shape::fp4x2_type); });
});
migraphx::verify::tolerance result{};
if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16)
if(has_fp4)
{
result.rms_tol = 8e-1;
result.atol = 4e-1;
result.rtol = 4e-1;
}
else if(has_16bit or vo.quantize == precision::fp16 or vo.quantize == precision::bf16)
{
result.rms_tol = 8e-2;
result.atol = 4e-2;
Expand Down Expand Up @@ -87,7 +99,8 @@ static std::vector<argument> run_ref(program p,
{
if(vo.ref_use_double)
{
run_passes(p, {fp_to_double{}});
run_passes(
p, {fp_to_double{}, simplify_qdq{.remove_qdq_only = true}, dead_code_elimination{}});
}
p.compile(migraphx::make_target("ref"), options);
auto out = p.eval(inputs);
Expand Down
2 changes: 1 addition & 1 deletion src/driver/verify_options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct verify_options
precision quantize = precision::fp32;

/**
* Converts floating point values to double on the ref target.
* Converts floating point values to double on the ref target. Also removes Q/DQ pairs on ref.
*/
bool ref_use_double = false;

Expand Down
4 changes: 3 additions & 1 deletion src/include/migraphx/simplify_qdq.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -38,6 +38,8 @@ struct module;
*/
struct MIGRAPHX_EXPORT simplify_qdq
{
    /** When true, apply() only runs the matcher that removes Q/DQ pairs. */
    bool remove_qdq_only{false};
    /**
     * When true (and remove_qdq_only is false), apply() also runs the
     * MX quantizable-ops matcher.
     */
    bool use_mx_quant{false};

    /** Returns the name of this pass. */
    std::string name() const
    {
        return "simplify_qdq";
    }

    /** Applies the pass to module `m`. */
    void apply(module& m) const;
};
Expand Down
36 changes: 23 additions & 13 deletions src/simplify_qdq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,19 +618,29 @@ void add_int4_pack_unpack_pair(module& m)

void simplify_qdq::apply(module& m) const
{
// first step: add pack/unpack pair between qdq for int4 weights
add_int4_pack_unpack_pair(m);
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_find_mx_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, remove_qdq_pairs{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_qlinear_reused{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_concat_qlinear{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_zero_point(m);
if(remove_qdq_only)
{
match::find_matches(m, remove_qdq_pairs{});
}
else
{
// first step: add pack/unpack pair between qdq for int4 weights
add_int4_pack_unpack_pair(m);
match::find_matches(m, match_find_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
if(use_mx_quant)
{
match::find_matches(m, match_find_mx_quantizable_ops{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
}
match::find_matches(m, remove_qdq_pairs{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_qlinear_reused{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
match::find_matches(m, match_concat_qlinear{});
migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
remove_zero_point(m);
}
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
6 changes: 6 additions & 0 deletions src/targets/gpu/device_name.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ bool gfx_has_bf16_intrinsics()
return not(starts_with(device_name, "gfx1030"));
}

bool gfx_has_mx_intrinsics()
{
const auto device_name = trim(split_string(get_device_name(), ':').front());
return starts_with(device_name, "gfx9") and device_name >= "gfx950";
}

#if MIGRAPHX_USE_HIPBLASLT
// Archs that support hipBLASLt but are defaulted to use rocBLAS.
bool gfx_default_rocblas()
Expand Down
2 changes: 2 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/device_name.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ MIGRAPHX_GPU_EXPORT bool gfx_has_fp8ocp_intrinsics();

MIGRAPHX_GPU_EXPORT bool gfx_has_bf16_intrinsics();

MIGRAPHX_GPU_EXPORT bool gfx_has_mx_intrinsics();

MIGRAPHX_GPU_EXPORT bool gfx_has_fp8fnuz_support();

#if MIGRAPHX_USE_HIPBLASLT
Expand Down
39 changes: 39 additions & 0 deletions src/targets/gpu/mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,11 @@ struct mlir_program

MlirType make_type(shape::type_t t) const
{
// non-computable type is not visit-able
if(t == shape::fp4x2_type)
{
return mlirFloat8E4M3FNTypeGet(ctx.get());
}
MlirType result;
shape::visit(t, [&](auto as) {
if(as.type_enum() == shape::float_type)
Expand Down Expand Up @@ -503,6 +508,17 @@ struct mlir_program
{
}

// Adds an `operandSegmentSizes` attribute to this operation state.
//
// num_segments: number of entries of `sizes` handed to MLIR.
// sizes:        per-segment operand counts; must hold at least
//               `num_segments` entries.
void set_operand_segment_sizes(int num_segments, const std::vector<int>& sizes)
{
    // MLIR receives a raw pointer plus a count, so a count larger than the
    // vector would read past the end of its storage.
    assert(num_segments >= 0);
    assert(static_cast<std::size_t>(num_segments) <= sizes.size());

    MlirAttribute segment_sizes_attr =
        mlirDenseI32ArrayGet(prog->ctx.get(), num_segments, sizes.data());
    MlirNamedAttribute named_attr = mlirNamedAttributeGet(
        mlirIdentifierGet(prog->ctx.get(),
                          mlirStringRefCreateFromCString("operandSegmentSizes")),
        segment_sizes_attr);
    // Exactly one named attribute is appended to the pending operation state.
    mlirOperationStateAddAttributes(&op_state, 1, &named_attr);
}

mlir_operation_state& add_attributes(const std::vector<named_attribute_t>& named_attrs)
{
auto attributes = prog->name_attributes(named_attrs);
Expand Down Expand Up @@ -647,6 +663,8 @@ struct mlir_program
return "migraphx.literal";
if(ins->name() == "unpack_int4")
return "migraphx.unpack";
if(ins->name() == "unpack_fp4")
return "migraphx.unpack";
if(ins->name() == "convolution_backwards")
return "migraphx.backwards_data_convolution";
if(is_reshape(ins->name()))
Expand Down Expand Up @@ -748,9 +766,29 @@ struct mlir_program
std::vector<MlirValue> inputs;
transform(
ins->inputs(), std::back_inserter(inputs), [&](auto i) { return ins_map.at(i); });
if(ins->name() == "dot") {
const std::vector<int> seg_sizes = {1, 1, 0, 0};
ops.set_operand_segment_sizes(4, seg_sizes);
}
else if(ins->name() == "quant_dot")
{
if(ins->inputs().size() == 4 and ins->inputs().front()->get_shape().type() == shape::fp8e4m3fn_type)
{
// Specify operand segment sizes BEFORE creating the operation so MLIR sees it.
// Use the canonical MLIR attribute name 'operandSegmentSizes'.
const std::vector<int> seg_sizes = {1, 1, 1, 1};
ops.set_operand_segment_sizes(4, seg_sizes);
}
else if(ins->inputs().size() == 2)
{
const std::vector<int> seg_sizes = {1, 1, 0, 0};
ops.set_operand_segment_sizes(4, seg_sizes);
}
}
ops.add_operands(inputs);

auto outputs = insert(fbody, std::move(ops));

if(ins->name() != "@return")
{
assert(outputs.size() == 1);
Expand Down Expand Up @@ -1201,6 +1239,7 @@ mlir_code_object compile_mlir(const context& migraphx_ctx,
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}

auto co = mp.compile(solution);

co.expected_inputs = in_shapes;
Expand Down
2 changes: 1 addition & 1 deletion src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), fp8_ocp_to_fnuz{}),
enable_pass(not gpu::gfx_has_fp8ocp_intrinsics() and gpu::gfx_has_fp8fnuz_intrinsics(), dead_code_elimination{}),
simplify_qdq{},
simplify_qdq{.use_mx_quant=gpu::gfx_has_mx_intrinsics()},
enable_pass(not mlir_enabled(), rewrite_quantization{}),
dead_code_elimination{},
rewrite_rnn{},
Expand Down
2 changes: 1 addition & 1 deletion src/targets/ref/target.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down
6 changes: 4 additions & 2 deletions test/simplify_qdq_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ static bool is_dot(const migraphx::instruction& ins) { return ins.name() == "dot

static void run_pass(migraphx::module& m)
{
run_passes(m, {migraphx::simplify_qdq{}, migraphx::dead_code_elimination{}});
run_passes(m,
{migraphx::simplify_qdq{.remove_qdq_only = false, .use_mx_quant = true},
migraphx::dead_code_elimination{}});
}

static void run_cse(migraphx::module& m)
Expand Down Expand Up @@ -1456,7 +1458,7 @@ TEST_CASE(dot_reused)
auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 1, sh.lens());
auto d2 = add_quantize_op(m2, "dequantizelinear", dot2, out_scale2);
auto d3 = add_quantize_op(m2, "dequantizelinear", q3, q3->inputs()[1]);
auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3);
auto add2 = m2.add_instruction(migraphx::make_op("add"), d2, d3);
m2.add_return({add2});
}

Expand Down
Loading
Loading