Skip to content

Commit 2891999

Browse files
committed
Lint
1 parent 455b786 commit 2891999

File tree

1 file changed

+51
-21
lines changed

1 file changed

+51
-21
lines changed

onnxruntime/test/testdata/test_dangling_input_segment_ids.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,53 +4,83 @@
44
python test_dangling_input_segment_ids.py out_model_path.onnx
55
"""
66

7-
from onnx import helper, numpy_helper, TensorProto
7+
import os
8+
import sys
89

9-
import onnx
1010
import numpy as np
11-
import sys
12-
import os
11+
import onnx
12+
from onnx import TensorProto, helper, numpy_helper
13+
14+
DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids")
1315

14-
DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_dangling_input_segment_ids')
1516

1617
def order_repeated_field(repeated_proto, key_name, order):
1718
order = list(order)
1819
repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))
1920

20-
def make_node(
21-
op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs
22-
):
23-
node = helper.make_node(
24-
op_type, inputs, outputs, name, doc_string, domain, **kwargs
25-
)
21+
22+
def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
23+
node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
2624
if doc_string == "":
2725
node.doc_string = ""
2826
order_repeated_field(node.attribute, "name", kwargs.keys())
2927
return node
3028

29+
3130
def make_graph(*args, doc_string=None, **kwargs):
3231
graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
3332
if doc_string == "":
3433
graph.doc_string = ""
3534
return graph
3635

36+
3737
model = helper.make_model(
38-
opset_imports=[helper.make_operatorsetid('', 14), helper.make_operatorsetid('com.microsoft', 1)],
38+
opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)],
3939
ir_version=7,
4040
graph=make_graph(
41-
name='embed_layernorm_graph',
42-
inputs=[helper.make_tensor_value_info('input_ids', TensorProto.INT32, shape=[1, 4]), helper.make_tensor_value_info('segment_ids', TensorProto.INT32, shape=[1, 4])],
43-
outputs=[helper.make_tensor_value_info('layernorm_out', TensorProto.FLOAT, shape=[1, 4, 4]), helper.make_tensor_value_info('mask_index_out', TensorProto.INT32, shape=[1])],
41+
name="embed_layernorm_graph",
42+
inputs=[
43+
helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]),
44+
helper.make_tensor_value_info("segment_ids", TensorProto.INT32, shape=[1, 4]),
45+
],
46+
outputs=[
47+
helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]),
48+
helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]),
49+
],
4450
initializer=[
45-
numpy_helper.from_array(np.load(os.path.join(DATA_DIR, 'const0_word_embed.npy')).astype('float32').reshape([32, 4]), name='word_embed'),
46-
numpy_helper.from_array(np.load(os.path.join(DATA_DIR, 'const1_pos_embed.npy')).astype('float32').reshape([16, 4]), name='pos_embed'),
47-
numpy_helper.from_array(np.array([0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495], dtype='float32'), name='gamma'),
48-
numpy_helper.from_array(np.array([0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype='float32'), name='beta'),
51+
numpy_helper.from_array(
52+
np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]),
53+
name="word_embed",
54+
),
55+
numpy_helper.from_array(
56+
np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]),
57+
name="pos_embed",
58+
),
59+
numpy_helper.from_array(
60+
np.array(
61+
[0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495],
62+
dtype="float32",
63+
),
64+
name="gamma",
65+
),
66+
numpy_helper.from_array(
67+
np.array(
68+
[0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32"
69+
),
70+
name="beta",
71+
),
72+
],
73+
nodes=[
74+
make_node(
75+
"EmbedLayerNormalization",
76+
inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"],
77+
outputs=["layernorm_out", "mask_index_out"],
78+
domain="com.microsoft",
79+
)
4980
],
50-
nodes=[make_node('EmbedLayerNormalization', inputs=['input_ids', '', 'word_embed', 'pos_embed', '', 'gamma', 'beta'], outputs=['layernorm_out', 'mask_index_out'], domain='com.microsoft')],
5181
),
5282
)
5383

54-
if __name__ == '__main__' and len(sys.argv) == 2:
84+
if __name__ == "__main__" and len(sys.argv) == 2:
5585
_, out_path = sys.argv
5686
onnx.save(model, out_path)

0 commit comments

Comments
 (0)