Skip to content

Implement left padding #2242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions keras_hub/src/layers/preprocessing/multi_segment_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
PreprocessingLayer,
)
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function

try:
Expand Down Expand Up @@ -66,6 +67,8 @@ class MultiSegmentPacker(PreprocessingLayer):
"waterfall" algorithm that allocates quota in a
left-to-right manner and fills up the buckets until we run
out of budget. It supports an arbitrary number of segments.
padding_side: str. Whether to pad the input on the "left" or "right".
Defaults to "right".

Returns:
A tuple with two elements. The first is the dense, packed token
Expand Down Expand Up @@ -124,6 +127,7 @@ def __init__(
sep_value=None,
pad_value=None,
truncate="round_robin",
padding_side="right",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a docstring.

**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -162,6 +166,7 @@ def check_special_value_type(value, value_name):
self.end_value = end_value

self.pad_value = pad_value
self.padding_side = padding_side

def get_config(self):
config = super().get_config()
Expand All @@ -173,6 +178,7 @@ def get_config(self):
"sep_value": self._sep_value,
"pad_value": self.pad_value,
"truncate": self.truncate,
"padding_side": self.padding_side,
}
)
return config
Expand Down Expand Up @@ -287,10 +293,18 @@ def call(
# Pad to dense tensor output.
sequence_length = sequence_length or self.sequence_length
shape = tf.cast([-1, sequence_length], "int64")
token_ids = token_ids.to_tensor(
shape=shape, default_value=self.pad_value
token_ids = pad(
token_ids,
shape=shape,
padding_side=self.padding_side,
pad_value=self.pad_value,
)
segment_ids = pad(
segment_ids,
shape=shape,
padding_side=self.padding_side,
pad_value=0,
)
segment_ids = segment_ids.to_tensor(shape=shape)
# Remove the batch dim if added.
if unbatched:
token_ids = tf.squeeze(token_ids, 0)
Expand Down
173 changes: 173 additions & 0 deletions keras_hub/src/layers/preprocessing/multi_segment_packer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

class MultiSegmentPackerTest(TestCase):
def test_trim_single_input_ints(self):
# right padding
input_data = np.arange(3, 10)
packer = MultiSegmentPacker(
sequence_length=8, start_value=1, end_value=2
Expand All @@ -16,7 +17,20 @@ def test_trim_single_input_ints(self):
self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])

# left padding
input_data = np.arange(3, 10)
packer = MultiSegmentPacker(
sequence_length=8,
start_value=1,
end_value=2,
padding_side="left",
)
token_ids, segment_ids = packer(input_data)
self.assertAllEqual(token_ids, [1, 3, 4, 5, 6, 7, 8, 2])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 0, 0, 0])

def test_trim_single_input_strings(self):
# right padding
input_data = ["a", "b", "c", "d"]
packer = MultiSegmentPacker(
sequence_length=5, start_value="[CLS]", end_value="[SEP]"
Expand All @@ -25,7 +39,19 @@ def test_trim_single_input_strings(self):
self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])

# left padding
packer = MultiSegmentPacker(
sequence_length=5,
start_value="[CLS]",
end_value="[SEP]",
padding_side="left",
)
token_ids, segment_ids = packer(input_data)
self.assertAllEqual(token_ids, ["[CLS]", "a", "b", "c", "[SEP]"])
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0])

def test_trim_multiple_inputs_round_robin(self):
# right padding
seq1 = ["a", "b", "c"]
seq2 = ["x", "y", "z"]
packer = MultiSegmentPacker(
Expand All @@ -40,7 +66,22 @@ def test_trim_multiple_inputs_round_robin(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
truncate="round_robin",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids, ["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"]
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1, 1])

def test_trim_multiple_inputs_waterfall(self):
# right padding
seq1 = ["a", "b", "c"]
seq2 = ["x", "y", "z"]
packer = MultiSegmentPacker(
Expand All @@ -55,7 +96,22 @@ def test_trim_multiple_inputs_waterfall(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
truncate="waterfall",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids, ["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"]
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 0, 1, 1])

def test_trim_batched_inputs_round_robin(self):
# right padding
seq1 = [["a", "b", "c"], ["a", "b", "c"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
packer = MultiSegmentPacker(
Expand All @@ -80,7 +136,32 @@ def test_trim_batched_inputs_round_robin(self):
],
)

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
truncate="round_robin",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
[
["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
],
)
self.assertAllEqual(
segment_ids,
[
[0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1],
],
)

def test_trim_batched_inputs_waterfall(self):
# right padding
seq1 = [["a", "b", "c"], ["a", "b"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
packer = MultiSegmentPacker(
Expand All @@ -105,7 +186,32 @@ def test_trim_batched_inputs_waterfall(self):
],
)

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
truncate="waterfall",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
[
["[CLS]", "a", "b", "c", "[SEP]", "x", "[SEP]"],
["[CLS]", "a", "b", "[SEP]", "x", "y", "[SEP]"],
],
)
self.assertAllEqual(
segment_ids,
[
[0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1, 1],
],
)

def test_pad_inputs(self):
# right padding
seq1 = ["a"]
seq2 = ["x"]
packer = MultiSegmentPacker(
Expand All @@ -118,7 +224,23 @@ def test_pad_inputs(self):
)
self.assertAllEqual(segment_ids, [0, 0, 0, 1, 1, 0])

# left padding
packer = MultiSegmentPacker(
6,
start_value="[CLS]",
end_value="[SEP]",
pad_value="[PAD]",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
["[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
)
self.assertAllEqual(segment_ids, [0, 0, 0, 0, 1, 1])

def test_pad_batched_inputs(self):
# right padding
seq1 = [["a"], ["a"]]
seq2 = [["x"], ["x", "y"]]
packer = MultiSegmentPacker(
Expand All @@ -143,7 +265,32 @@ def test_pad_batched_inputs(self):
],
)

# left padding
packer = MultiSegmentPacker(
sequence_length=7,
start_value="[CLS]",
end_value="[SEP]",
pad_value="[PAD]",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
[
["[PAD]", "[PAD]", "[CLS]", "a", "[SEP]", "x", "[SEP]"],
["[PAD]", "[CLS]", "a", "[SEP]", "x", "y", "[SEP]"],
],
)
self.assertAllEqual(
segment_ids,
[
[0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 1, 1, 1],
],
)

def test_list_special_tokens(self):
# right padding
seq1 = [["a", "b"], ["a", "b"]]
seq2 = [["x", "y"], ["x"]]
packer = MultiSegmentPacker(
Expand All @@ -170,6 +317,32 @@ def test_list_special_tokens(self):
],
)

# left padding
packer = MultiSegmentPacker(
8,
start_value="<s>",
end_value="</s>",
sep_value=["</s>", "</s>"],
pad_value="<pad>",
truncate="round_robin",
padding_side="left",
)
token_ids, segment_ids = packer((seq1, seq2))
self.assertAllEqual(
token_ids,
[
["<s>", "a", "b", "</s>", "</s>", "x", "y", "</s>"],
["<pad>", "<s>", "a", "b", "</s>", "</s>", "x", "</s>"],
],
)
self.assertAllEqual(
segment_ids,
[
[0, 0, 0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1],
],
)

def test_config(self):
seq1 = [["a", "b", "c"], ["a", "b"]]
seq2 = [["x", "y", "z"], ["x", "y", "z"]]
Expand Down
Loading
Loading