Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions src/fuse_pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ static literal get_scalar(instruction_ref ins)
if(contains({"contiguous", "broadcast", "multibroadcast"}, ins->name()))
return get_scalar(ins->inputs().front());
const auto& s = ins->get_shape();
if(s.elements() != 1 and not(s.scalar()))
if(s.dynamic() or (s.elements() != 1 and not(s.scalar())))
return {};
if(not ins->can_eval())
return {};
Expand Down Expand Up @@ -330,16 +330,20 @@ struct pointwise_reshape : rewrite_reshapes_base
static std::string name() { return "pointwise"; }
};

struct pointwise_broadcast_pointwise
struct pointwise_broadcast_pointwise : match::supports_dynamic_shapes
{
auto matcher() const
{
auto pointwise = match::name("pointwise")(match::used_once()).bind("x");
auto broadcast_pointwise =
match::name("multibroadcast")(
match::used_once(),
match::args(match::name("pointwise")(match::used_once()).bind("x")))
match::name("multibroadcast")(match::used_once(), match::args(pointwise))
.bind("broadcast");
return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise));
auto dyn_broadcast_pointwise = match::name("multibroadcast")(match::used_once(),
match::nargs(2),
match::arg(1)(pointwise))
.bind("broadcast");
return match::name("pointwise")(match::any_of[match::inputs()](
match::any_of(broadcast_pointwise, dyn_broadcast_pointwise)));
}

void apply(module& m, const match::matcher_result& r) const
Expand All @@ -359,11 +363,39 @@ struct pointwise_broadcast_pointwise
}
};

// Use pointwise instruction input as reference for dynamic multibroadcast rather than
// the pointwise instruction itself
struct dyn_pointwise_broadcast : match::supports_dynamic_shapes
{
auto matcher() const
{
auto broadcast_pointwise =
match::name("multibroadcast")(
match::nargs(2),
match::arg(0)(match::name("pointwise").bind("x")))
.bind("broadcast");
return match::name("pointwise")(match::any_of[match::inputs()](broadcast_pointwise));
}

void apply(module& m, const match::matcher_result& r) const
{
auto broadcast_ins = r.instructions["broadcast"];
auto x_ins = r.instructions["x"];

auto broadcast_inps = broadcast_ins->inputs();
broadcast_inps[0] = x_ins->inputs().front();

m.replace_instruction(broadcast_ins, broadcast_ins->get_operator(), broadcast_inps);
}
};

} // namespace

static void rewrite_broadcasts(module_pass_manager& mpm)
{
match::find_matches(mpm.get_module(), pointwise_broadcast_pointwise{});
mpm.get_module().debug_print();
match::find_matches(
mpm.get_module(), dyn_pointwise_broadcast{}, pointwise_broadcast_pointwise{});
mpm.run_pass(dead_code_elimination{});
}

Expand Down
7 changes: 5 additions & 2 deletions src/include/migraphx/op/pointwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,14 @@ struct pointwise
MIGRAPHX_THROW("pointwise should have at least one input");
auto* pm = mods.front();
auto pnames = pm->get_parameter_names();
check_shapes{inputs, *this}.has(pnames.size()).same_dims();
check_shapes{inputs, *this, true}.has(pnames.size()).same_dims();

std::vector<std::size_t> scalar_const_out_lens =
inputs.front().dynamic() ? std::vector<std::size_t>{} : inputs.front().lens();

auto result = pm->compute_shapes(
inputs,
{.name = name(), .strict_type = true, .scalar_const_out_lens = inputs.front().lens()});
{.name = name(), .strict_type = true, .scalar_const_out_lens = scalar_const_out_lens});
if(result.size() == 1)
return result.front();
return shape{result};
Expand Down
2 changes: 1 addition & 1 deletion src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
enable_pass(mlir_attention_enabled(&ctx), fuse_attention{}),
dead_code_elimination{},
optimize_module{},
enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_pointwise_reduce{}),
fuse_pointwise_reduce{},
dead_code_elimination{},
#ifndef _WIN32
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
Expand Down
Loading