4 | 4 | python test_dangling_input_segment_ids.py out_model_path.onnx
5 | 5 | """
6 | 6
7 | | -from onnx import helper, numpy_helper, TensorProto |
| 7 | +import os |
| 8 | +import sys |
8 | 9 |
9 | | -import onnx |
10 | 10 | import numpy as np |
11 | | -import sys |
12 | | -import os |
| 11 | +import onnx |
| 12 | +from onnx import TensorProto, helper, numpy_helper |
| 13 | + |
| 14 | +DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids") |
13 | 15 |
14 | | -DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test_dangling_input_segment_ids') |
15 | 16 |
16 | 17 | def order_repeated_field(repeated_proto, key_name, order): |
17 | 18 | order = list(order) |
18 | 19 | repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) |
19 | 20 |
20 | | -def make_node( |
21 | | - op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs |
22 | | -): |
23 | | - node = helper.make_node( |
24 | | - op_type, inputs, outputs, name, doc_string, domain, **kwargs |
25 | | - ) |
| 21 | + |
| 22 | +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): |
| 23 | + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) |
26 | 24 | if doc_string == "": |
27 | 25 | node.doc_string = "" |
28 | 26 | order_repeated_field(node.attribute, "name", kwargs.keys()) |
29 | 27 | return node |
30 | 28 |
| 29 | + |
31 | 30 | def make_graph(*args, doc_string=None, **kwargs): |
32 | 31 | graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) |
33 | 32 | if doc_string == "": |
34 | 33 | graph.doc_string = "" |
35 | 34 | return graph |
36 | 35 |
| 36 | + |
37 | 37 | model = helper.make_model( |
38 | | - opset_imports=[helper.make_operatorsetid('', 14), helper.make_operatorsetid('com.microsoft', 1)], |
| 38 | + opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)], |
39 | 39 | ir_version=7, |
40 | 40 | graph=make_graph( |
41 | | - name='embed_layernorm_graph', |
42 | | - inputs=[helper.make_tensor_value_info('input_ids', TensorProto.INT32, shape=[1, 4]), helper.make_tensor_value_info('segment_ids', TensorProto.INT32, shape=[1, 4])], |
43 | | - outputs=[helper.make_tensor_value_info('layernorm_out', TensorProto.FLOAT, shape=[1, 4, 4]), helper.make_tensor_value_info('mask_index_out', TensorProto.INT32, shape=[1])], |
| 41 | + name="embed_layernorm_graph", |
| 42 | + inputs=[ |
| 43 | + helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]), |
| 44 | + helper.make_tensor_value_info("segment_ids", TensorProto.INT32, shape=[1, 4]), |
| 45 | + ], |
| 46 | + outputs=[ |
| 47 | + helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]), |
| 48 | + helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]), |
| 49 | + ], |
44 | 50 | initializer=[ |
45 | | - numpy_helper.from_array(np.load(os.path.join(DATA_DIR, 'const0_word_embed.npy')).astype('float32').reshape([32, 4]), name='word_embed'), |
46 | | - numpy_helper.from_array(np.load(os.path.join(DATA_DIR, 'const1_pos_embed.npy')).astype('float32').reshape([16, 4]), name='pos_embed'), |
47 | | - numpy_helper.from_array(np.array([0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495], dtype='float32'), name='gamma'), |
48 | | - numpy_helper.from_array(np.array([0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype='float32'), name='beta'), |
| 51 | + numpy_helper.from_array( |
| 52 | + np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]), |
| 53 | + name="word_embed", |
| 54 | + ), |
| 55 | + numpy_helper.from_array( |
| 56 | + np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]), |
| 57 | + name="pos_embed", |
| 58 | + ), |
| 59 | + numpy_helper.from_array( |
| 60 | + np.array( |
| 61 | + [0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495], |
| 62 | + dtype="float32", |
| 63 | + ), |
| 64 | + name="gamma", |
| 65 | + ), |
| 66 | + numpy_helper.from_array( |
| 67 | + np.array( |
| 68 | + [0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32" |
| 69 | + ), |
| 70 | + name="beta", |
| 71 | + ), |
| 72 | + ], |
| 73 | + nodes=[ |
| 74 | + make_node( |
| 75 | + "EmbedLayerNormalization", |
| 76 | + inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"], |
| 77 | + outputs=["layernorm_out", "mask_index_out"], |
| 78 | + domain="com.microsoft", |
| 79 | + ) |
49 | 80 | ], |
50 | | - nodes=[make_node('EmbedLayerNormalization', inputs=['input_ids', '', 'word_embed', 'pos_embed', '', 'gamma', 'beta'], outputs=['layernorm_out', 'mask_index_out'], domain='com.microsoft')], |
51 | 81 | ), |
52 | 82 | ) |
53 | 83 |
54 | | -if __name__ == '__main__' and len(sys.argv) == 2: |
| 84 | +if __name__ == "__main__" and len(sys.argv) == 2: |
55 | 85 | _, out_path = sys.argv |
56 | 86 | onnx.save(model, out_path) |
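
Note: the regenerated model deliberately leaves segment_ids dangling. It is declared as a graph input, but the EmbedLayerNormalization node's segment-id and mask slots are passed as empty strings, so no node consumes it. As a minimal sketch (not part of this change), one way to confirm that after running the script, assuming the example output path from the docstring above:

import onnx
from onnx import checker

m = onnx.load("out_model_path.onnx")  # example path; any path passed to the script works
checker.check_model(m)

# Compare inputs declared on the graph against names actually read by some node.
consumed = {name for node in m.graph.node for name in node.input if name}
dangling = [inp.name for inp in m.graph.input if inp.name not in consumed]
print(dangling)  # expected: ['segment_ids']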