@@ -47,7 +47,7 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
4747 return " fuse_bn_into_conv" ;
4848 }
4949
50- bool modify_conv (Node* conv, Node* bn, Graph& graph) {
50+ bool modify_conv (Node* conv, Node* bn, Graph& graph, const bool is_conv ) {
5151 const auto & bn_inputs = bn->inputs ();
5252 const auto & conv_inputs = conv->inputs ();
5353
@@ -123,10 +123,9 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
123123 Node* unsqueeze = graph.create (kUnsqueeze , 1 );
124124 unsqueeze->insertAfter (scale);
125125 unsqueeze->addInput (scale->output ());
126- std::vector<int64_t > insert_dims;
127- for (int i = 1 ; i < conv_W.sizes ().size (); ++i) {
128- insert_dims.push_back (i);
129- }
126+ std::vector<int64_t > insert_dims (conv_W.sizes ().size ());
127+ std::iota (insert_dims.begin (), insert_dims.end (), 0 );
128+ insert_dims.erase (insert_dims.begin () + (is_conv ? 0 : 1 ));
130129 if (getOpsetVersion (graph) > 11 ) {
131130 Tensor shape_s_t ;
132131 shape_s_t .elem_type () = ONNX_NAMESPACE::TensorProto_DataType_INT64;
@@ -181,7 +180,8 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
181180 }
182181
183182 bool patternMatchPredicate (Node* n) override {
184- return CheckKind (n, kBatchNormalization , 0 , kConv ) &&
183+ return (CheckKind (n, kBatchNormalization , 0 , kConv ) ||
184+ CheckKind (n, kBatchNormalization , 0 , kConvTranspose )) &&
185185 GetValueFromAttrWithDefault (n, " training_mode" , (int64_t )0 ) == 0 &&
186186 n->input (0 )->uses ().size () == 1 && n->outputs ().size () == 1 &&
187187 IsConstantTensor (n, 1 ) && IsConstantTensor (n, 2 ) &&
@@ -190,10 +190,12 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
190190 }
191191 bool runTransform (Node* n, Graph& graph,
192192 NodeDestroyType& destroy_current) override {
193+ const bool is_conv = CheckKind (n, kBatchNormalization , 0 , kConv );
194+
193195 Node* bn = n;
194196 Node* conv = PrevNode (n, 0 );
195197 auto origInput = bn->inputs ()[0 ];
196- if (!modify_conv (conv, bn, graph)) {
198+ if (!modify_conv (conv, bn, graph, is_conv )) {
197199 destroy_current = NodeDestroyType::DestroyZero;
198200 return false ;
199201 }
0 commit comments