diff --git a/tools/importers/onnx/lib/onnx_converters.py b/tools/importers/onnx/lib/onnx_converters.py index 98016e706..1bbe61797 100644 --- a/tools/importers/onnx/lib/onnx_converters.py +++ b/tools/importers/onnx/lib/onnx_converters.py @@ -811,7 +811,7 @@ def slice_tensor(self, node): if axis == 0: tensor = tensor[starts:ends,] elif axis == 1: - tensor =tensor[:,starts:ends,] + tensor = tensor[:,starts:ends,] elif axis == 2: tensor = tensor[:,:,starts:ends,] elif axis == 3: @@ -1279,11 +1279,14 @@ def get_attributes(self, attrs: Attributes): def get_output_shapes(self): # now compute the output size - channel = 0 if len(self.node.input_shapes) > 2: # third input is the output shape + input_shape = self.node.input_shapes[2][0] + if len(input_shape) == 1: + # normalize the shape + input_shape = (input_shape[0], 1) + return [(input_shape, self.get_order(input_shape))] return [self.node.input_shapes[2]] - if len(self.node.input_shapes) == 2: n, m1 = self.node.input_shapes[0][0] if attributes["transA"]: @@ -1303,7 +1306,6 @@ def get_output_shapes(self): else: raise Exception("Gemm operation is expecting two inputs, but we have {}".format(len(self.node.input_shapes))) - return [(result, self.get_order(result))] def get_weights(self): @@ -1324,7 +1326,7 @@ def get_weights(self): tensor_shape = tensor.shape tensor_len = len(tensor_shape) transpose = self.node.attributes["transB"] - if tensor_len == 2 : + if tensor_len == 2: if len(input_shape) == 3: tensor = self.reshape_3d_into_2d_tensor(input_shape, weights, transpose) tensor_shape = tensor.shape @@ -1336,12 +1338,12 @@ def get_weights(self): # from what you'd expect doing "weight * input", not "input * weight" so we may need to transpose here # again to account for this. input_shape = (np.product(input_shape), 1) - if tensor_shape[1] == input_shape[0]: # make sure we have m*n and n*k. + if tensor_shape[1] == input_shape[0]: # make sure we have m*n and n*k. pass elif tensor_shape[0] == input_shape[0]: # then the weights need to be transposed. tensor = tensor.T - self.add_tensor(weights[0], tensor) # re-register the transformed version. + self.add_tensor(weights[0], tensor) # re-register the transformed version. else: raise Exception("Cannot multiply matrices of incompatible shapes {} x {}".format(tensor_shape, input_shape)) @@ -1373,9 +1375,17 @@ def get_attributes(self, attrs: Attributes): if 'strides' in attrs: self.strides = attrs['strides'] attributes['stride'] = self.strides[0] + if len(self.strides) > 1: + for s in self.strides[1:]: + if s != self.strides[0]: + raise Exception("Multiple strides {} is not supported".format(self.strides)) if 'pads' in attrs: self.padding = attrs['pads'] attributes['padding'] = self.padding[0] + if len(self.padding) > 1: + for s in self.padding[1:]: + if s != self.padding[0]: + raise Exception("Multiple padding {} is not supported".format(self.padding)) if 'dilations' in attrs: self.dilations = attrs['dilations'] attributes['dilation'] = self.dilations[0]