Skip to content

Commit c314f88

Browse files
authored
implement of leftpadding (#2242)
* implement of leftpadding * add doc * update * fix * format * format * update test * add left padding for segment * add doc.
1 parent 7ab2c53 commit c314f88

File tree

5 files changed

+405
-9
lines changed

5 files changed

+405
-9
lines changed

keras_hub/src/layers/preprocessing/multi_segment_packer.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
PreprocessingLayer,
44
)
55
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
6+
from keras_hub.src.utils.tensor_utils import pad
67
from keras_hub.src.utils.tensor_utils import preprocessing_function
78

89
try:
@@ -66,6 +67,8 @@ class MultiSegmentPacker(PreprocessingLayer):
6667
"waterfall" algorithm that allocates quota in a
6768
left-to-right manner and fills up the buckets until we run
6869
out of budget. It support arbitrary number of segments.
70+
padding_side: str. Whether to pad the input on the "left" or "right".
71+
Defaults to "right".
6972
7073
Returns:
7174
A tuple with two elements. The first is the dense, packed token
@@ -124,6 +127,7 @@ def __init__(
124127
sep_value=None,
125128
pad_value=None,
126129
truncate="round_robin",
130+
padding_side="right",
127131
**kwargs,
128132
):
129133
super().__init__(**kwargs)
@@ -162,6 +166,7 @@ def check_special_value_type(value, value_name):
162166
self.end_value = end_value
163167

164168
self.pad_value = pad_value
169+
self.padding_side = padding_side
165170

166171
def get_config(self):
167172
config = super().get_config()
@@ -173,6 +178,7 @@ def get_config(self):
173178
"sep_value": self._sep_value,
174179
"pad_value": self.pad_value,
175180
"truncate": self.truncate,
181+
"padding_side": self.padding_side,
176182
}
177183
)
178184
return config
@@ -287,10 +293,18 @@ def call(
287293
# Pad to dense tensor output.
288294
sequence_length = sequence_length or self.sequence_length
289295
shape = tf.cast([-1, sequence_length], "int64")
290-
token_ids = token_ids.to_tensor(
291-
shape=shape, default_value=self.pad_value
296+
token_ids = pad(
297+
token_ids,
298+
shape=shape,
299+
padding_side=self.padding_side,
300+
pad_value=self.pad_value,
301+
)
302+
segment_ids = pad(
303+
segment_ids,
304+
shape=shape,
305+
padding_side=self.padding_side,
306+
pad_value=0,
292307
)
293-
segment_ids = segment_ids.to_tensor(shape=shape)
294308
# Remove the batch dim if added.
295309
if unbatched:
296310
token_ids = tf.squeeze(token_ids, 0)

keras_hub/src/layers/preprocessing/multi_segment_packer_test.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
class MultiSegmentPackerTest(TestCase):
1010
def test_trim_single_input_ints(self):
11+
# right padding
1112
input_data = np.arange(3, 10)
1213
packer = MultiSegmentPacker(
1314
sequence_length=8, start_value=1, end_value=2
@@ -16,7 +17,20 @@ def test_trim_single_input_ints(self):
1617
self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
1718
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])
1819

20+
# left padding
21+
input_data = np.arange(3, 10)
22+
packer = MultiSegmentPacker(
23+
sequence_length=8,
24+
start_value=1,
25+
end_value=2,
26+
padding_side="left",
27+
)
28+
token_ids, segment_ids = packer(input_data)
29+
self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
30+
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])
31+
1932
def test_trim_single_input_strings(self):
33+
# right padding
2034
input_data = ["a", "b", "c", "d"]
2135
packer = MultiSegmentPacker(
2236
sequence_length=5, start_value="[CLS]", end_value="[SEP]"
@@ -25,7 +39,19 @@ def test_trim_single_input_strings(self):
2539
self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
2640
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])
2741

42+
# left padding
43+
packer = MultiSegmentPacker(
44+
sequence_length=5,
45+
start_value="[CLS]",
46+
end_value="[SEP]",
47+
padding_side="left",
48+
)
49+
token_ids, segment_ids = packer(input_data)
50+
self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
51+
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])
52+
2853
def test_trim_multiple_inputs_round_robin(self):
54+
# right padding
2955
seq1 = ["a", "b", "c"]
3056
seq2 = ["x", "y", "z"]
3157
packer = MultiSegmentPacker(
@@ -40,7 +66,22 @@ def test_trim_multiple_inputs_round_robin(self):
4066
)
4167
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])
4268

69+
# left padding
70+
packer = MultiSegmentPacker(
71+
sequence_length=7,
72+
start_value="[CLS]",
73+
end_value="[SEP]",
74+
truncate="round_robin",
75+
padding_side="left",
76+
)
77+
token_ids, segment_ids = packer((seq1, seq2))
78+
self.assertAllEqual(
79+
token_ids, ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"]
80+
)
81+
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])
82+
4383
def test_trim_multiple_inputs_waterfall(self):
84+
# right padding
4485
seq1 = ["a", "b", "c"]
4586
seq2 = ["x", "y", "z"]
4687
packer = MultiSegmentPacker(
@@ -55,7 +96,22 @@ def test_trim_multiple_inputs_waterfall(self):
5596
)
5697
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])
5798

99+
# left padding
100+
packer = MultiSegmentPacker(
101+
sequence_length=7,
102+
start_value="[CLS]",
103+
end_value="[SEP]",
104+
truncate="waterfall",
105+
padding_side="left",
106+
)
107+
token_ids, segment_ids = packer((seq1, seq2))
108+
self.assertAllEqual(
109+
token_ids, ["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"]
110+
)
111+
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])
112+
58113
def test_trim_batched_inputs_round_robin(self):
114+
# right padding
59115
seq1 = [["a", "b", "c"], ["a", "b", "c"]]
60116
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
61117
packer = MultiSegmentPacker(
@@ -80,7 +136,32 @@ def test_trim_batched_inputs_round_robin(self):
80136
],
81137
)
82138

139+
# left padding
140+
packer = MultiSegmentPacker(
141+
sequence_length=7,
142+
start_value="[CLS]",
143+
end_value="[SEP]",
144+
truncate="round_robin",
145+
padding_side="left",
146+
)
147+
token_ids, segment_ids = packer((seq1, seq2))
148+
self.assertAllEqual(
149+
token_ids,
150+
[
151+
["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
152+
["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
153+
],
154+
)
155+
self.assertAllEqual(
156+
segment_ids,
157+
[
158+
[0, 0, 0, 0, 1, 1, 1],
159+
[0, 0, 0, 0, 1, 1, 1],
160+
],
161+
)
162+
83163
def test_trim_batched_inputs_waterfall(self):
164+
# right padding
84165
seq1 = [["a", "b", "c"], ["a", "b"]]
85166
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
86167
packer = MultiSegmentPacker(
@@ -105,7 +186,32 @@ def test_trim_batched_inputs_waterfall(self):
105186
],
106187
)
107188

189+
# left padding
190+
packer = MultiSegmentPacker(
191+
sequence_length=7,
192+
start_value="[CLS]",
193+
end_value="[SEP]",
194+
truncate="waterfall",
195+
padding_side="left",
196+
)
197+
token_ids, segment_ids = packer((seq1, seq2))
198+
self.assertAllEqual(
199+
token_ids,
200+
[
201+
["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"],
202+
["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
203+
],
204+
)
205+
self.assertAllEqual(
206+
segment_ids,
207+
[
208+
[0, 0, 0, 0, 0, 1, 1],
209+
[0, 0, 0, 0, 1, 1, 1],
210+
],
211+
)
212+
108213
def test_pad_inputs(self):
214+
# right padding
109215
seq1 = ["a"]
110216
seq2 = ["x"]
111217
packer = MultiSegmentPacker(
@@ -118,7 +224,23 @@ def test_pad_inputs(self):
118224
)
119225
self.assertAllEqual(segment_ids, [0, 0, 0, 1, 1, 0])
120226

227+
# left padding
228+
packer = MultiSegmentPacker(
229+
6,
230+
start_value="[CLS]",
231+
end_value="[SEP]",
232+
pad_value="[PAD]",
233+
padding_side="left",
234+
)
235+
token_ids, segment_ids = packer((seq1, seq2))
236+
self.assertAllEqual(
237+
token_ids,
238+
["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
239+
)
240+
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1])
241+
121242
def test_pad_batched_inputs(self):
243+
# right padding
122244
seq1 = [["a"], ["a"]]
123245
seq2 = [["x"], ["x", "y"]]
124246
packer = MultiSegmentPacker(
@@ -143,7 +265,32 @@ def test_pad_batched_inputs(self):
143265
],
144266
)
145267

268+
# left padding
269+
packer = MultiSegmentPacker(
270+
sequence_length=7,
271+
start_value="[CLS]",
272+
end_value="[SEP]",
273+
pad_value="[PAD]",
274+
padding_side="left",
275+
)
276+
token_ids, segment_ids = packer((seq1, seq2))
277+
self.assertAllEqual(
278+
token_ids,
279+
[
280+
["[PAD]", "[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
281+
["[PAD]", "[CLS]", "a", "[SEP]", "x", "y", "[SEP]"],
282+
],
283+
)
284+
self.assertAllEqual(
285+
segment_ids,
286+
[
287+
[0, 0, 0, 0, 0, 1, 1],
288+
[0, 0, 0, 0, 1, 1, 1],
289+
],
290+
)
291+
146292
def test_list_special_tokens(self):
293+
# right padding
147294
seq1 = [["a", "b"], ["a", "b"]]
148295
seq2 = [["x", "y"], ["x"]]
149296
packer = MultiSegmentPacker(
@@ -170,6 +317,32 @@ def test_list_special_tokens(self):
170317
],
171318
)
172319

320+
# left padding
321+
packer = MultiSegmentPacker(
322+
8,
323+
start_value="<s>",
324+
end_value="</s>",
325+
sep_value=["</s>", "</s>"],
326+
pad_value="<pad>",
327+
truncate="round_robin",
328+
padding_side="left",
329+
)
330+
token_ids, segment_ids = packer((seq1, seq2))
331+
self.assertAllEqual(
332+
token_ids,
333+
[
334+
["<s>", "a", "b", "</s>", "</s>", "x", "y", "</s>"],
335+
["<pad>", "<s>", "a", "b", "</s>", "</s>", "x", "</s>"],
336+
],
337+
)
338+
self.assertAllEqual(
339+
segment_ids,
340+
[
341+
[0, 0, 0, 0, 0, 1, 1, 1],
342+
[0, 0, 0, 0, 0, 0, 1, 1],
343+
],
344+
)
345+
173346
def test_config(self):
174347
seq1 = [["a", "b", "c"], ["a", "b"]]
175348
seq2 = [["x", "y", "z"], ["x", "y", "z"]]

0 commit comments

Comments
 (0)