@@ -32,10 +32,21 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
3232 std::string getPassName () const override {
3333 return " fuse_add_bias_into_conv" ;
3434 }
35- bool patternMatchPredicate (Node *node) override {
35+
36+ inline bool matchConvAdd (Node *node) {
3637 return node->kind () == kAdd && node->inputs ()[0 ]->node ()->kind () == kConv &&
3738 node->inputs ()[0 ]->node ()->inputs ().size () == 2 ;
3839 }
40+
41+ inline bool matchConvTransposeAdd (Node *node) {
42+ return node->kind () == kAdd && node->inputs ()[0 ]->node ()->kind () == kConvTranspose &&
43+ node->inputs ()[0 ]->node ()->inputs ().size () == 2 ;
44+ }
45+
46+ bool patternMatchPredicate (Node *node) override {
47+ return matchConvAdd (node) || matchConvTransposeAdd (node);
48+ }
49+
3950 static Node *makeSqueezeOrUnsqueeze (Graph &graph, std::vector<int64_t > &axes,
4051 Value *input, Node *target_node,
4152 BuiltinSymbol k) {
@@ -61,6 +72,7 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
6172 NodeDestroyType &destroy_current) override {
6273 // due to current broadcasting's constraint, Conv has to be the first
6374 // operand
75+ const bool is_conv = matchConvAdd (n);
6476 destroy_current = NodeDestroyType::DestroyZero;
6577 auto orig_conv = n->inputs ()[0 ];
6678 auto orig_bias = n->inputs ()[1 ];
@@ -85,8 +97,8 @@ struct FuseAddBiasIntoConv final : public PredicateBasedPass {
8597 }
8698 // try to get feature M and rank from weight_shape
8799 if (weight_shape.size () > 0 && weight_shape[0 ].is_int ) {
88- ONNX_ASSERT (M == -1 || M == weight_shape[0 ].dim );
89- M = weight_shape[0 ].dim ;
100+ ONNX_ASSERT (M == -1 || M == weight_shape[0 ].dim || M == weight_shape[ 1 ]. dim );
101+ M = is_conv ? weight_shape[0 ]. dim : weight_shape[ 1 ].dim ;
90102 ONNX_ASSERT (rank == -1 ||
91103 rank == static_cast <int64_t >(weight_shape.size ()));
92104 rank = weight_shape.size ();
0 commit comments