11"""
22Run this script to recreate the original onnx model.
33Example usage:
4- python test_dangling_input_segment_ids.py out_model_path .onnx
4+ python test_dangling_input_segment_ids.py test_dangling_input_segment_ids .onnx
55"""
66
7- import os
87import sys
98
109import numpy as np
1110import onnx
12- from onnx import TensorProto , helper , numpy_helper
13-
14- DATA_DIR = os .path .join (os .path .dirname (os .path .realpath (__file__ )), "test_dangling_input_segment_ids" )
1511
1612
1713def order_repeated_field (repeated_proto , key_name , order ):
@@ -20,50 +16,104 @@ def order_repeated_field(repeated_proto, key_name, order):
2016
2117
2218def make_node (op_type , inputs , outputs , name = None , doc_string = None , domain = None , ** kwargs ):
23- node = helper .make_node (op_type , inputs , outputs , name , doc_string , domain , ** kwargs )
19+ node = onnx . helper .make_node (op_type , inputs , outputs , name , doc_string , domain , ** kwargs )
2420 if doc_string == "" :
2521 node .doc_string = ""
2622 order_repeated_field (node .attribute , "name" , kwargs .keys ())
2723 return node
2824
2925
3026def make_graph (* args , doc_string = None , ** kwargs ):
31- graph = helper .make_graph (* args , doc_string = doc_string , ** kwargs )
27+ graph = onnx . helper .make_graph (* args , doc_string = doc_string , ** kwargs )
3228 if doc_string == "" :
3329 graph .doc_string = ""
3430 return graph
3531
3632
37- model = helper .make_model (
38- opset_imports = [helper .make_operatorsetid ("" , 14 ), helper .make_operatorsetid ("com.microsoft" , 1 )],
33+ WORD_EMBED = np .array (
34+ [
35+ [0.31524479389190674 , 0.8928887248039246 , 0.5778571963310242 , 0.18401020765304565 ],
36+ [0.7879292368888855 , 0.6120311617851257 , 0.05390927195549011 , 0.4201936721801758 ],
37+ [0.6790688633918762 , 0.9186017513275146 , 0.0004020248888991773 , 0.976759135723114 ],
38+ [0.3765803277492523 , 0.973783552646637 , 0.6047161221504211 , 0.8288457989692688 ],
39+ [0.5747115015983582 , 0.6280761957168579 , 0.2855762839317322 , 0.5868333578109741 ],
40+ [0.750021755695343 , 0.8583138585090637 , 0.7550821900367737 , 0.698057234287262 ],
41+ [0.8644794225692749 , 0.3226810097694397 , 0.6707887649536133 , 0.4508739411830902 ],
42+ [0.38210275769233704 , 0.4108113646507263 , 0.401479572057724 , 0.31738394498825073 ],
43+ [0.6219193935394287 , 0.4302472770214081 , 0.9738020896911621 , 0.6778008937835693 ],
44+ [0.1985698938369751 , 0.42670100927352905 , 0.3433462381362915 , 0.7976388335227966 ],
45+ [0.8799982666969299 , 0.9038419723510742 , 0.6627197861671448 , 0.2702082693576813 ],
46+ [0.25236669182777405 , 0.8548979163169861 , 0.5277146697044373 , 0.8021610975265503 ],
47+ [0.57248854637146 , 0.7331425547599792 , 0.5190116167068481 , 0.7708839178085327 ],
48+ [0.5688579678535461 , 0.4657098650932312 , 0.3426889181137085 , 0.06820935010910034 ],
49+ [0.3779241740703583 , 0.07962607592344284 , 0.9828171133995056 , 0.18161284923553467 ],
50+ [0.8118587136268616 , 0.8749616742134094 , 0.6884132623672485 , 0.5694944262504578 ],
51+ [0.16097143292427063 , 0.46688002347946167 , 0.34517204761505127 , 0.22503995895385742 ],
52+ [0.5925118923187256 , 0.31226983666419983 , 0.9163055419921875 , 0.9096355438232422 ],
53+ [0.257118284702301 , 0.11089129745960236 , 0.19296273589134216 , 0.4995841681957245 ],
54+ [0.7285856604576111 , 0.20819443464279175 , 0.2480335533618927 , 0.8516718745231628 ],
55+ [0.4158487319946289 , 0.6166850924491882 , 0.23366613686084747 , 0.10196726024150848 ],
56+ [0.5158570408821106 , 0.47714099287986755 , 0.15267165005207062 , 0.6218062043190002 ],
57+ [0.5440101027488708 , 0.654137372970581 , 0.1445455402135849 , 0.7515278458595276 ],
58+ [0.22204914689064026 , 0.5193518400192261 , 0.7852960228919983 , 0.022330427542328835 ],
59+ [0.32436245679855347 , 0.8729223608970642 , 0.8447096347808838 , 0.5384405851364136 ],
60+ [0.8666082620620728 , 0.9498059749603271 , 0.8264070153236389 , 0.8541154265403748 ],
61+ [0.09874340146780014 , 0.651304304599762 , 0.703516960144043 , 0.6102408170700073 ],
62+ [0.7996152639389038 , 0.034571219235658646 , 0.7702387571334839 , 0.7317286133766174 ],
63+ [0.25969839096069336 , 0.25706928968429565 , 0.6323032975196838 , 0.3452974557876587 ],
64+ [0.796588659286499 , 0.4461462199687958 , 0.7827494144439697 , 0.9904717803001404 ],
65+ [0.30024832487106323 , 0.143005833029747 , 0.9013084173202515 , 0.5415593981742859 ],
66+ [0.9747403860092163 , 0.6366044282913208 , 0.9939129948616028 , 0.5460708141326904 ],
67+ ],
68+ dtype = np .float32 ,
69+ )
70+
71+ POS_EMBED = np .array (
72+ [
73+ [0.5264259576797485 , 0.13542790710926056 , 0.3557051718235016 , 0.026218567043542862 ],
74+ [0.16039517521858215 , 0.7456371784210205 , 0.030399689450860023 , 0.36654308438301086 ],
75+ [0.8623462319374084 , 0.6926777362823486 , 0.6909421682357788 , 0.18863679468631744 ],
76+ [0.4419042766094208 , 0.5815774202346802 , 0.9897516965866089 , 0.20390622317790985 ],
77+ [0.24773290753364563 , 0.2621730864048004 , 0.7501724362373352 , 0.4569753408432007 ],
78+ [0.056929439306259155 , 0.508516252040863 , 0.21196016669273376 , 0.7986042499542236 ],
79+ [0.29733139276504517 , 0.027606012299656868 , 0.5934324264526367 , 0.8438404202461243 ],
80+ [0.3810161352157593 , 0.7498583197593689 , 0.5111414790153503 , 0.5409517884254456 ],
81+ [0.9594343304634094 , 0.803960919380188 , 0.032323066145181656 , 0.7093872427940369 ],
82+ [0.46500149369239807 , 0.9475489258766174 , 0.22143273055553436 , 0.26707202196121216 ],
83+ [0.08147396147251129 , 0.42861881852149963 , 0.10901876538991928 , 0.6337867379188538 ],
84+ [0.8029632568359375 , 0.6968004703521729 , 0.7662113904953003 , 0.34245410561561584 ],
85+ [0.845851480960846 , 0.4287687838077545 , 0.824009895324707 , 0.6264961361885071 ],
86+ [0.14342305064201355 , 0.07838690280914307 , 0.018332643434405327 , 0.0667250007390976 ],
87+ [0.458583801984787 , 0.11334192007780075 , 0.0277833491563797 , 0.7548614740371704 ],
88+ [0.394850492477417 , 0.7469384670257568 , 0.45240482687950134 , 0.4500867426395416 ],
89+ ],
90+ dtype = np .float32 ,
91+ )
92+
93+ model = onnx .helper .make_model (
94+ opset_imports = [onnx .helper .make_operatorsetid ("" , 14 ), onnx .helper .make_operatorsetid ("com.microsoft" , 1 )],
3995 ir_version = 7 ,
4096 graph = make_graph (
4197 name = "embed_layernorm_graph" ,
4298 inputs = [
43- helper .make_tensor_value_info ("input_ids" , TensorProto .INT32 , shape = [1 , 4 ]),
44- helper .make_tensor_value_info ("segment_ids" , TensorProto .INT32 , shape = [1 , 4 ]),
99+ onnx . helper .make_tensor_value_info ("input_ids" , onnx . TensorProto .INT32 , shape = [1 , 4 ]),
100+ onnx . helper .make_tensor_value_info ("segment_ids" , onnx . TensorProto .INT32 , shape = [1 , 4 ]),
45101 ],
46102 outputs = [
47- helper .make_tensor_value_info ("layernorm_out" , TensorProto .FLOAT , shape = [1 , 4 , 4 ]),
48- helper .make_tensor_value_info ("mask_index_out" , TensorProto .INT32 , shape = [1 ]),
103+ onnx . helper .make_tensor_value_info ("layernorm_out" , onnx . TensorProto .FLOAT , shape = [1 , 4 , 4 ]),
104+ onnx . helper .make_tensor_value_info ("mask_index_out" , onnx . TensorProto .INT32 , shape = [1 ]),
49105 ],
50106 initializer = [
51- numpy_helper .from_array (
52- np .load (os .path .join (DATA_DIR , "const0_word_embed.npy" )).astype ("float32" ).reshape ([32 , 4 ]),
53- name = "word_embed" ,
54- ),
55- numpy_helper .from_array (
56- np .load (os .path .join (DATA_DIR , "const1_pos_embed.npy" )).astype ("float32" ).reshape ([16 , 4 ]),
57- name = "pos_embed" ,
58- ),
59- numpy_helper .from_array (
107+ onnx .numpy_helper .from_array (WORD_EMBED , name = "word_embed" ),
108+ onnx .numpy_helper .from_array (POS_EMBED , name = "pos_embed" ),
109+ onnx .numpy_helper .from_array (
60110 np .array (
61111 [0.6185135841369629 , 0.010364261455833912 , 0.5386272668838501 , 0.0030179566238075495 ],
62112 dtype = "float32" ,
63113 ),
64114 name = "gamma" ,
65115 ),
66- numpy_helper .from_array (
116+ onnx . numpy_helper .from_array (
67117 np .array (
68118 [0.9511938095092773 , 0.9054020047187805 , 0.7959669232368469 , 0.9152743220329285 ], dtype = "float32"
69119 ),
0 commit comments