Skip to content

Commit a1198f2

Browse files
author
wenyuchi.wyc
committed
Support fuse add into ConvTranspose.
Signed-off-by: wenyuchi.wyc <[email protected]>
1 parent 74fdf9c commit a1198f2

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

onnxoptimizer/passes/fuse_add_bias_into_conv.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,21 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
3232
std::string getPassName() const override {
3333
return "fuse_add_bias_into_conv";
3434
}
35-
bool patternMatchPredicate(Node *node) override {
35+
36+
inline bool matchConvAdd(Node *node) {
3637
return node->kind() == kAdd && node->inputs()[0]->node()->kind() == kConv &&
3738
node->inputs()[0]->node()->inputs().size() == 2;
3839
}
40+
41+
inline bool matchConvTransposeAdd(Node *node) {
42+
return node->kind() == kAdd && node->inputs()[0]->node()->kind() == kConvTranspose &&
43+
node->inputs()[0]->node()->inputs().size() == 2;
44+
}
45+
46+
bool patternMatchPredicate(Node *node) override {
47+
return matchConvAdd(node) || matchConvTransposeAdd(node);
48+
}
49+
3950
static Node *makeSqueezeOrUnsqueeze(Graph &graph, std::vector<int64_t> &axes,
4051
Value *input, Node *target_node,
4152
BuiltinSymbol k) {
@@ -61,6 +72,7 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
6172
NodeDestroyType &destroy_current) override {
6273
// due to current broadcasting's constraint, Conv has to be the first
6374
// operand
75+
const bool is_conv = matchConvAdd(n);
6476
destroy_current = NodeDestroyType::DestroyZero;
6577
auto orig_conv = n->inputs()[0];
6678
auto orig_bias = n->inputs()[1];
@@ -85,8 +97,8 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
8597
}
8698
// try to get feature M and rank from weight_shape
8799
if (weight_shape.size() > 0 && weight_shape[0].is_int) {
88-
ONNX_ASSERT(M == -1 || M == weight_shape[0].dim);
89-
M = weight_shape[0].dim;
100+
ONNX_ASSERT(M == -1 || M == weight_shape[0].dim || M == weight_shape[1].dim);
101+
M = is_conv ? weight_shape[0].dim : weight_shape[1].dim;
90102
ONNX_ASSERT(rank == -1 ||
91103
rank == static_cast<int64_t>(weight_shape.size()));
92104
rank = weight_shape.size();

0 commit comments

Comments
 (0)