8
8
9
9
class MultiSegmentPackerTest (TestCase ):
10
10
def test_trim_single_input_ints (self ):
11
+ # right padding
11
12
input_data = np .arange (3 , 10 )
12
13
packer = MultiSegmentPacker (
13
14
sequence_length = 8 , start_value = 1 , end_value = 2
@@ -16,7 +17,20 @@ def test_trim_single_input_ints(self):
16
17
self .assertAllEqual (token_ids , [1 , 3 , 4 , 5 , 6 , 7 , 8 , 2 ])
17
18
self .assertAllEqual (segment_ids , [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ])
18
19
20
+ # left padding
21
+ input_data = np .arange (3 , 10 )
22
+ packer = MultiSegmentPacker (
23
+ sequence_length = 8 ,
24
+ start_value = 1 ,
25
+ end_value = 2 ,
26
+ padding_side = "left" ,
27
+ )
28
+ token_ids , segment_ids = packer (input_data )
29
+ self .assertAllEqual (token_ids , [1 , 3 , 4 , 5 , 6 , 7 , 8 , 2 ])
30
+ self .assertAllEqual (segment_ids , [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ])
31
+
19
32
def test_trim_single_input_strings (self ):
33
+ # right padding
20
34
input_data = ["a" , "b" , "c" , "d" ]
21
35
packer = MultiSegmentPacker (
22
36
sequence_length = 5 , start_value = "[CLS]" , end_value = "[SEP]"
@@ -25,7 +39,19 @@ def test_trim_single_input_strings(self):
25
39
self .assertAllEqual (token_ids , ["[CLS]" , "a" , "b" , "c" , "[SEP]" ])
26
40
self .assertAllEqual (segment_ids , [0 , 0 , 0 , 0 , 0 ])
27
41
42
+ # left padding
43
+ packer = MultiSegmentPacker (
44
+ sequence_length = 5 ,
45
+ start_value = "[CLS]" ,
46
+ end_value = "[SEP]" ,
47
+ padding_side = "left" ,
48
+ )
49
+ token_ids , segment_ids = packer (input_data )
50
+ self .assertAllEqual (token_ids , ["[CLS]" , "a" , "b" , "c" , "[SEP]" ])
51
+ self .assertAllEqual (segment_ids , [0 , 0 , 0 , 0 , 0 ])
52
+
28
53
def test_trim_multiple_inputs_round_robin (self ):
54
+ # right padding
29
55
seq1 = ["a" , "b" , "c" ]
30
56
seq2 = ["x" , "y" , "z" ]
31
57
packer = MultiSegmentPacker (
@@ -40,7 +66,22 @@ def test_trim_multiple_inputs_round_robin(self):
40
66
)
41
67
self .assertAllEqual (segment_ids , [0 , 0 , 0 , 0 , 1 , 1 , 1 ])
42
68
69
+ # left padding
70
+ packer = MultiSegmentPacker (
71
+ sequence_length = 7 ,
72
+ start_value = "[CLS]" ,
73
+ end_value = "[SEP]" ,
74
+ truncate = "round_robin" ,
75
+ padding_side = "left" ,
76
+ )
77
+ token_ids , segment_ids = packer ((seq1 , seq2 ))
78
+ self .assertAllEqual (
79
+ token_ids , ["[CLS]" , "a" , "b" , "[SEP]" , "x" , "y" , "[SEP]" ]
80
+ )
81
+ self .assertAllEqual (segment_ids , [0 , 0 , 0 , 0 , 1 , 1 , 1 ])
82
+
43
83
def test_trim_multiple_inputs_waterfall (self ):
84
+ # right padding
44
85
seq1 = ["a" , "b" , "c" ]
45
86
seq2 = ["x" , "y" , "z" ]
46
87
packer = MultiSegmentPacker (
@@ -55,7 +96,22 @@ def test_trim_multiple_inputs_waterfall(self):
55
96
)
56
97
self .assertAllEqual (segment_ids , [0 , 0 , 0 , 0 , 0 , 1 , 1 ])
57
98
99
+ # left padding
100
+ packer = MultiSegmentPacker (
101
+ sequence_length = 7 ,
102
+ start_value = "[CLS]" ,
103
+ end_value = "[SEP]" ,
104
+ truncate = "waterfall" ,
105
+ padding_side = "left" ,
106
+ )
107
+ token_ids , segment_ids = packer ((seq1 , seq2 ))
108
+ self .assertAllEqual (
109
+ token_ids , ["[CLS]" , "a" , "b" , "c" , "[SEP]" , "x" , "[SEP]" ]
110
+ )
111
+ self .assertAllEqual (segment_ids , [0 , 0 , 0 , 0 , 0 , 1 , 1 ])
112
+
58
113
def test_trim_batched_inputs_round_robin (self ):
114
+ # right padding
59
115
seq1 = [["a" , "b" , "c" ], ["a" , "b" , "c" ]]
60
116
seq2 = [["x" , "y" , "z" ], ["x" , "y" , "z" ]]
61
117
packer = MultiSegmentPacker (
@@ -80,7 +136,32 @@ def test_trim_batched_inputs_round_robin(self):
80
136
],
81
137
)
82
138
139
+ # left padding
140
+ packer = MultiSegmentPacker (
141
+ sequence_length = 7 ,
142
+ start_value = "[CLS]" ,
143
+ end_value = "[SEP]" ,
144
+ truncate = "round_robin" ,
145
+ padding_side = "left" ,
146
+ )
147
+ token_ids , segment_ids = packer ((seq1 , seq2 ))
148
+ self .assertAllEqual (
149
+ token_ids ,
150
+ [
151
+ ["[CLS]" , "a" , "b" , "[SEP]" , "x" , "y" , "[SEP]" ],
152
+ ["[CLS]" , "a" , "b" , "[SEP]" , "x" , "y" , "[SEP]" ],
153
+ ],
154
+ )
155
+ self .assertAllEqual (
156
+ segment_ids ,
157
+ [
158
+ [0 , 0 , 0 , 0 , 1 , 1 , 1 ],
159
+ [0 , 0 , 0 , 0 , 1 , 1 , 1 ],
160
+ ],
161
+ )
162
+
83
163
def test_trim_batched_inputs_waterfall (self ):
164
+ # right padding
84
165
seq1 = [["a" , "b" , "c" ], ["a" , "b" ]]
85
166
seq2 = [["x" , "y" , "z" ], ["x" , "y" , "z" ]]
86
167
packer = MultiSegmentPacker (
@@ -105,7 +186,32 @@ def test_trim_batched_inputs_waterfall(self):
105
186
],
106
187
)
107
188
189
+ # left padding
190
+ packer = MultiSegmentPacker (
191
+ sequence_length = 7 ,
192
+ start_value = "[CLS]" ,
193
+ end_value = "[SEP]" ,
194
+ truncate = "waterfall" ,
195
+ padding_side = "left" ,
196
+ )
197
+ token_ids , segment_ids = packer ((seq1 , seq2 ))
198
+ self .assertAllEqual (
199
+ token_ids ,
200
+ [
201
+ ["[CLS]" , "a" , "b" , "c" , "[SEP]" , "x" , "[SEP]" ],
202
+ ["[CLS]" , "a" , "b" , "[SEP]" , "x" , "y" , "[SEP]" ],
203
+ ],
204
+ )
205
+ self .assertAllEqual (
206
+ segment_ids ,
207
+ [
208
+ [0 , 0 , 0 , 0 , 0 , 1 , 1 ],
209
+ [0 , 0 , 0 , 0 , 1 , 1 , 1 ],
210
+ ],
211
+ )
212
+
108
213
def test_pad_inputs (self ):
214
+ # right padding
109
215
seq1 = ["a" ]
110
216
seq2 = ["x" ]
111
217
packer = MultiSegmentPacker (
@@ -118,7 +224,23 @@ def test_pad_inputs(self):
118
224
)
119
225
self .assertAllEqual (segment_ids , [0 , 0 , 0 , 1 , 1 , 0 ])
120
226
227
+ # left padding
228
+ packer = MultiSegmentPacker (
229
+ 6 ,
230
+ start_value = "[CLS]" ,
231
+ end_value = "[SEP]" ,
232
+ pad_value = "[PAD]" ,
233
+ padding_side = "left" ,
234
+ )
235
+ token_ids , segment_ids = packer ((seq1 , seq2 ))
236
+ self .assertAllEqual (
237
+ token_ids ,
238
+ ["[PAD]" , "[CLS]" , "a" , "[SEP]" , "x" , "[SEP]" ],
239
+ )
240
+ self .assertAllEqual (segment_ids , [0 , 0 , 0 , 0 , 1 , 1 ])
241
+
121
242
def test_pad_batched_inputs (self ):
243
+ # right padding
122
244
seq1 = [["a" ], ["a" ]]
123
245
seq2 = [["x" ], ["x" , "y" ]]
124
246
packer = MultiSegmentPacker (
@@ -143,7 +265,32 @@ def test_pad_batched_inputs(self):
143
265
],
144
266
)
145
267
268
+ # left padding
269
+ packer = MultiSegmentPacker (
270
+ sequence_length = 7 ,
271
+ start_value = "[CLS]" ,
272
+ end_value = "[SEP]" ,
273
+ pad_value = "[PAD]" ,
274
+ padding_side = "left" ,
275
+ )
276
+ token_ids , segment_ids = packer ((seq1 , seq2 ))
277
+ self .assertAllEqual (
278
+ token_ids ,
279
+ [
280
+ ["[PAD]" , "[PAD]" , "[CLS]" , "a" , "[SEP]" , "x" , "[SEP]" ],
281
+ ["[PAD]" , "[CLS]" , "a" , "[SEP]" , "x" , "y" , "[SEP]" ],
282
+ ],
283
+ )
284
+ self .assertAllEqual (
285
+ segment_ids ,
286
+ [
287
+ [0 , 0 , 0 , 0 , 0 , 1 , 1 ],
288
+ [0 , 0 , 0 , 0 , 1 , 1 , 1 ],
289
+ ],
290
+ )
291
+
146
292
def test_list_special_tokens (self ):
293
+ # right padding
147
294
seq1 = [["a" , "b" ], ["a" , "b" ]]
148
295
seq2 = [["x" , "y" ], ["x" ]]
149
296
packer = MultiSegmentPacker (
@@ -170,6 +317,32 @@ def test_list_special_tokens(self):
170
317
],
171
318
)
172
319
320
+ # left padding
321
+ packer = MultiSegmentPacker (
322
+ 8 ,
323
+ start_value = "<s>" ,
324
+ end_value = "</s>" ,
325
+ sep_value = ["</s>" , "</s>" ],
326
+ pad_value = "<pad>" ,
327
+ truncate = "round_robin" ,
328
+ padding_side = "left" ,
329
+ )
330
+ token_ids , segment_ids = packer ((seq1 , seq2 ))
331
+ self .assertAllEqual (
332
+ token_ids ,
333
+ [
334
+ ["<s>" , "a" , "b" , "</s>" , "</s>" , "x" , "y" , "</s>" ],
335
+ ["<pad>" , "<s>" , "a" , "b" , "</s>" , "</s>" , "x" , "</s>" ],
336
+ ],
337
+ )
338
+ self .assertAllEqual (
339
+ segment_ids ,
340
+ [
341
+ [0 , 0 , 0 , 0 , 0 , 1 , 1 , 1 ],
342
+ [0 , 0 , 0 , 0 , 0 , 0 , 1 , 1 ],
343
+ ],
344
+ )
345
+
173
346
def test_config (self ):
174
347
seq1 = [["a" , "b" , "c" ], ["a" , "b" ]]
175
348
seq2 = [["x" , "y" , "z" ], ["x" , "y" , "z" ]]
0 commit comments