Skip to content
This repository was archived by the owner on Jul 17, 2024. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions tools/importers/onnx/lib/onnx_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def slice_tensor(self, node):
if axis == 0:
tensor = tensor[starts:ends,]
elif axis == 1:
tensor =tensor[:,starts:ends,]
tensor = tensor[:,starts:ends,]
elif axis == 2:
tensor = tensor[:,:,starts:ends,]
elif axis == 3:
Expand Down Expand Up @@ -1279,11 +1279,14 @@ def get_attributes(self, attrs: Attributes):

def get_output_shapes(self):
# now compute the output size
channel = 0
if len(self.node.input_shapes) > 2:
# third input is the output shape
input_shape = self.node.input_shapes[2][0]
if len(input_shape) == 1:
# normalize the shape
input_shape = (input_shape[0], 1)
return [(input_shape, self.get_order(input_shape))]
return [self.node.input_shapes[2]]

if len(self.node.input_shapes) == 2:
n, m1 = self.node.input_shapes[0][0]
if attributes["transA"]:
Expand All @@ -1303,7 +1306,6 @@ def get_output_shapes(self):

else:
raise Exception("Gemm operation is expecting two inputs, but we have {}".format(len(self.node.input_shapes)))

return [(result, self.get_order(result))]

def get_weights(self):
Expand All @@ -1324,7 +1326,7 @@ def get_weights(self):
tensor_shape = tensor.shape
tensor_len = len(tensor_shape)
transpose = self.node.attributes["transB"]
if tensor_len == 2 :
if tensor_len == 2:
if len(input_shape) == 3:
tensor = self.reshape_3d_into_2d_tensor(input_shape, weights, transpose)
tensor_shape = tensor.shape
Expand All @@ -1336,12 +1338,12 @@ def get_weights(self):
# from what you'd expect doing "weight * input", not "input * weight" so we may need to transpose here
# again to account for this.
input_shape = (np.product(input_shape), 1)
if tensor_shape[1] == input_shape[0]: # make sure we have m*n and n*k.
if tensor_shape[1] == input_shape[0]: # make sure we have m*n and n*k.
pass
elif tensor_shape[0] == input_shape[0]:
# then the weights need to be transposed.
tensor = tensor.T
self.add_tensor(weights[0], tensor) # re-register the transformed version.
self.add_tensor(weights[0], tensor) # re-register the transformed version.
else:
raise Exception("Cannot multiply matrices of incompatible shapes {} x {}".format(tensor_shape, input_shape))

Expand Down Expand Up @@ -1373,9 +1375,17 @@ def get_attributes(self, attrs: Attributes):
if 'strides' in attrs:
self.strides = attrs['strides']
attributes['stride'] = self.strides[0]
if len(self.strides) > 1:
for s in self.strides[1:]:
if s != self.strides[0]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about "all(y == x[0] for y in x[1:])"

raise Exception("Multiple strides {} is not supported".format(self.strides))
if 'pads' in attrs:
self.padding = attrs['pads']
attributes['padding'] = self.padding[0]
if len(self.padding) > 1:
for s in self.padding[1:]:
if s != self.padding[0]:
raise Exception("Multiple padding {} is not supported".format(self.padding))
if 'dilations' in attrs:
self.dilations = attrs['dilations']
attributes['dilation'] = self.dilations[0]
Expand Down