From eb471ba97097c93f9506439eb478abafc0c80e8e Mon Sep 17 00:00:00 2001
From: Shiv
Date: Mon, 13 Oct 2025 15:39:52 -0700
Subject: [PATCH] initial enablement

Summary of the changes in this patch:

* fuse_pointwise.cpp, get_scalar(): return an empty literal when the
  instruction's shape is dynamic, in addition to the existing
  "more than one element and not scalar" case.
* fuse_pointwise.cpp, pointwise_broadcast_pointwise: derive from
  match::supports_dynamic_shapes and additionally match a two-argument
  multibroadcast whose second argument (arg 1) is a used-once pointwise.
* fuse_pointwise.cpp, dyn_pointwise_broadcast (new): for a two-argument
  multibroadcast whose first argument (arg 0) is a pointwise instruction,
  replace that argument with the pointwise instruction's first input.
  It is registered ahead of pointwise_broadcast_pointwise in
  rewrite_broadcasts().
* op/pointwise.hpp, compute_shape(): pass `true` as the third
  check_shapes argument (presumably "allow dynamic shapes" - confirm
  against check_shapes), and pass empty scalar_const_out_lens when the
  first input is dynamic instead of calling lens() on it.
* targets/gpu/target.cpp: run fuse_pointwise_reduce unconditionally
  instead of only when MIGRAPHX_ENABLE_FULL_DYNAMIC is disabled.
---
Review notes (ignored by git am):
 - Dropped a stray `mpm.get_module().debug_print();` from
   rewrite_broadcasts(); hunk and diffstat counts are adjusted to match.
 - Line breaks and indentation were reconstructed from a
   whitespace-collapsed copy; apply with --ignore-whitespace if context
   lines fail to match.

 src/fuse_pointwise.cpp                | 45 ++++++++++++++++++++++-----
 src/include/migraphx/op/pointwise.hpp |  7 +++--
 src/targets/gpu/target.cpp            |  2 +-
 3 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/src/fuse_pointwise.cpp b/src/fuse_pointwise.cpp
index f66e449dfc7..1456c1ceac2 100644
--- a/src/fuse_pointwise.cpp
+++ b/src/fuse_pointwise.cpp
@@ -46,7 +46,7 @@ static literal get_scalar(instruction_ref ins)
     if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name()))
         return get_scalar(ins->inputs().front());
     const auto& s = ins->get_shape();
-    if(s.elements() != 1 and not(s.scalar()))
+    if(s.dynamic() or (s.elements() != 1 and not(s.scalar())))
         return {};
     if(not ins->can_eval())
         return {};
@@ -330,16 +330,20 @@ struct pointwise_reshape : rewrite_reshapes_base
     static std::string name() { return "pointwise"; }
 };

-struct pointwise_broadcast_pointwise
+struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes
 {
     auto matcher() const
     {
+        auto pointwise = match::name("pointwise")(match::used_once()).bind("x");
         auto broadcast_pointwise =
-            match::name("multibroadcast")(
-                match::used_once(),
-                match::args(match::name("pointwise")(match::used_once()).bind("x")))
+            match::name("multibroadcast")(match::used_once(), match::args(pointwise))
                 .bind("broadcast");
-        return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise));
+        auto dyn_broadcast_pointwise = match::name("multibroadcast")(match::used_once(),
+                                                                     match::nargs(2),
+                                                                     match::arg(1)(pointwise))
+                                           .bind("broadcast");
+        return match::name("pointwise")(match::any_of[match::inputs()](
+            match::any_of(broadcast_pointwise, dyn_broadcast_pointwise)));
     }

     void apply(module& m, const match::matcher_result& r) const
@@ -359,11 +363,38 @@ struct pointwise_broadcast_pointwise
     }
 };

+// Use pointwise instruction input as reference for dynamic multibroadcast rather than
+// the pointwise instruction itself
+struct dyn_pointwise_broadcast : match::supports_dynamic_shapes
+{
+    auto matcher() const
+    {
+        auto broadcast_pointwise =
+            match::name("multibroadcast")(
+                match::nargs(2),
+                match::arg(0)(match::name("pointwise").bind("x")))
+                .bind("broadcast");
+        return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto broadcast_ins = r.instructions["broadcast"];
+        auto x_ins = r.instructions["x"];
+
+        auto broadcast_inps = broadcast_ins->inputs();
+        broadcast_inps[0] = x_ins->inputs().front();
+
+        m.replace_instruction(broadcast_ins, broadcast_ins->get_operator(), broadcast_inps);
+    }
+};
+
 } // namespace

 static void rewrite_broadcasts(module_pass_manager& mpm)
 {
-    match::find_matches(mpm.get_module(), pointwise_broadcast_pointwise{});
+    match::find_matches(
+        mpm.get_module(), dyn_pointwise_broadcast{}, pointwise_broadcast_pointwise{});
     mpm.run_pass(dead_code_elimination{});
 }

diff --git a/src/include/migraphx/op/pointwise.hpp b/src/include/migraphx/op/pointwise.hpp
index 2f5116dc654..2372e71c7f7 100644
--- a/src/include/migraphx/op/pointwise.hpp
+++ b/src/include/migraphx/op/pointwise.hpp
@@ -49,11 +49,14 @@ struct pointwise
             MIGRAPHX_THROW("pointwise should have at least one input");
         auto* pm = mods.front();
         auto pnames = pm->get_parameter_names();
-        check_shapes{inputs, *this}.has(pnames.size()).same_dims();
+        check_shapes{inputs, *this, true}.has(pnames.size()).same_dims();
+
+        std::vector<std::size_t> scalar_const_out_lens =
+            inputs.front().dynamic() ? std::vector<std::size_t>{} : inputs.front().lens();

         auto result = pm->compute_shapes(
             inputs,
-            {.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()});
+            {.name = name(), .strict_type = true, .scalar_const_out_lens = scalar_const_out_lens});
         if(result.size() == 1)
             return result.front();
         return shape{result};
diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp
index fbc540b20d9..4a8b3a986d8 100644
--- a/src/targets/gpu/target.cpp
+++ b/src/targets/gpu/target.cpp
@@ -228,7 +228,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         enable_pass(mlir_attention_enabled(&ctx), fuse_attention{}),
         dead_code_elimination{},
         optimize_module{},
-        enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_pointwise_reduce{}),
+        fuse_pointwise_reduce{},
         dead_code_elimination{},
 #ifndef _WIN32
         enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),