Skip to content

Add left pad for CausalLMPreprocessor #2343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions keras_hub/src/layers/preprocessing/multi_segment_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,10 @@ def call(
sequence_length=None,
add_start_value=True,
add_end_value=True,
padding_side=None,
):
inputs, unbatched = self._sanitize_inputs(inputs)

padding_side = padding_side or self.padding_side
segments = self._trim_inputs(inputs)
token_ids, segment_ids = self._combine_inputs(
segments,
Expand All @@ -296,13 +297,13 @@ def call(
token_ids = pad(
token_ids,
shape=shape,
padding_side=self.padding_side,
padding_side=padding_side,
pad_value=self.pad_value,
)
segment_ids = pad(
segment_ids,
shape=shape,
padding_side=self.padding_side,
padding_side=padding_side,
pad_value=0,
)
# Remove the batch dim if added.
Expand Down
7 changes: 4 additions & 3 deletions keras_hub/src/layers/preprocessing/start_end_packer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,12 @@ def call(
sequence_length=None,
add_start_value=True,
add_end_value=True,
padding_side=None,
):
inputs, unbatched, rectangular = convert_to_ragged_batch(inputs)
x = inputs # Intermediate result.

batch_size = tf.shape(x)[0]
padding_side = padding_side or self.padding_side
sequence_length = sequence_length or self.sequence_length
dtype = inputs.dtype
# Truncate.
Expand Down Expand Up @@ -185,7 +186,7 @@ def call(
outputs = pad(
x,
pad_value=self.pad_value,
padding_side=self.padding_side,
padding_side=padding_side,
shape=(batch_size, sequence_length),
)
outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs
Expand All @@ -196,7 +197,7 @@ def call(
mask = pad(
mask,
pad_value=False,
padding_side=self.padding_side,
padding_side=padding_side,
shape=(batch_size, sequence_length),
)
mask = tf.squeeze(mask, axis=0) if unbatched else mask
Expand Down
86 changes: 75 additions & 11 deletions keras_hub/src/models/causal_lm_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from keras_hub.src.utils.tensor_utils import preprocessing_function
from keras_hub.src.utils.tensor_utils import strip_to_ragged

try:
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
except ImportError:
dynamic_update_slice = None


@keras_hub_export("keras_hub.models.CausalLMPreprocessor")
class CausalLMPreprocessor(Preprocessor):
Expand Down Expand Up @@ -64,6 +69,7 @@ def __init__(
sequence_length=1024,
add_start_token=True,
add_end_token=True,
padding_side="right",
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -72,6 +78,8 @@ def __init__(
self.sequence_length = sequence_length
self.add_start_token = add_start_token
self.add_end_token = add_end_token
assert padding_side in ["right", "left"]
self.padding_side = padding_side

def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
Expand All @@ -82,6 +90,7 @@ def build(self, input_shape):
pad_value=self.tokenizer.pad_token_id,
sequence_length=self.sequence_length,
return_padding_mask=True,
padding_side=self.padding_side,
)
self.built = True

Expand All @@ -92,16 +101,43 @@ def call(
y=None,
sample_weight=None,
sequence_length=None,
padding_side=None,
):
sequence_length = sequence_length or self.sequence_length
x = self.tokenizer(x)
# Pad with one extra token to account for the truncation below.
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length + 1,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
padding_side = padding_side or self.padding_side
sequence_length = sequence_length or self.sequence_length
if padding_side == "left":
addition_token_num = int(self.add_start_token + self.add_end_token)
token_ids, padding_mask = self.packer(
x,
sequence_length=x.to_tensor().shape[-1] + addition_token_num,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using x.to_tensor().shape[-1] can be inefficient because it creates a dense tensor from a ragged one just to get its shape. A more performant way to get the length of the longest sequence in a tf.RaggedTensor is to use x.bounding_shape()[-1]. This avoids the overhead of creating the intermediate dense tensor. This optimization also applies to line 177 in generate_preprocess.

Suggested change
sequence_length=x.to_tensor().shape[-1] + addition_token_num,
sequence_length=x.bounding_shape()[-1] + addition_token_num,

add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
padding_side=padding_side,
)
token_ids, all_padding_mask = self.packer(
token_ids,
sequence_length=sequence_length + 1,
add_start_value=False,
add_end_value=False,
padding_side="right",
)
if dynamic_update_slice is None:
raise ImportError(
"Left padding on CausalLMPreprocessor requires TensorFlow"
)
padding_mask = dynamic_update_slice(
all_padding_mask, padding_mask, [0] * len(padding_mask.shape)
)
else:
# Pad with one extra token to account for the truncation below.
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length + 1,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
padding_side=padding_side,
)
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
Expand All @@ -116,6 +152,7 @@ def generate_preprocess(
self,
x,
sequence_length=None,
padding_side=None,
):
"""Convert strings to integer token input for generation.

Expand All @@ -130,11 +167,37 @@ def generate_preprocess(
"""
if not self.built:
self.build(None)
padding_side = padding_side or self.padding_side

x = self.tokenizer(x)
token_ids, padding_mask = self.packer(
x, sequence_length=sequence_length, add_end_value=False
)
if padding_side == "left":
token_ids, padding_mask = self.packer(
x,
sequence_length=x.to_tensor().shape[-1] + 1,
add_end_value=False,
padding_side=padding_side,
)
token_ids, all_padding_mask = self.packer(
token_ids,
sequence_length=sequence_length,
add_start_value=False,
add_end_value=False,
padding_side="right",
)
if dynamic_update_slice is None:
raise ImportError(
"Left padding on CausalLMPreprocessor requires TensorFlow"
)
padding_mask = dynamic_update_slice(
all_padding_mask, padding_mask, [0] * len(padding_mask.shape)
)
else:
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length,
add_end_value=False,
padding_side=padding_side,
)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
Expand Down Expand Up @@ -166,6 +229,7 @@ def get_config(self):
"sequence_length": self.sequence_length,
"add_start_token": self.add_start_token,
"add_end_token": self.add_end_token,
"padding_side": self.padding_side,
}
)
return config
Expand Down
102 changes: 102 additions & 0 deletions keras_hub/src/models/causal_lm_preprocessor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,108 @@ def test_preset_accessors(self):
self.assertTrue(bert_presets.isdisjoint(all_presets))
self.assertTrue(gpt2_presets.issubset(all_presets))

def test_padding_side(self):
    """Verify `padding_side` is honored by `CausalLMPreprocessor.__call__`.

    Runs the same two prompts through the preprocessor with left and
    right padding and checks token ids, padding masks, labels, and
    sample weights against hand-computed GPT-2 tokenizations.
    """
    preprocessor = CausalLMPreprocessor.from_preset(
        "gpt2_base_en", sequence_length=7
    )
    prompts = ["i love you", "this is keras hub"]

    # Left padding: pad values are inserted at the front of each row.
    x, y, sample_weight = preprocessor(prompts, padding_side="left")
    self.assertAllEqual(
        x["token_ids"],
        (
            [0, 0, 50256, 72, 1842, 345, 50256],
            [50256, 5661, 318, 41927, 292, 12575, 50256],
        ),
    )
    self.assertAllEqual(
        x["padding_mask"],
        ([0, 0, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]),
    )
    self.assertAllEqual(
        y,
        (
            [0, 50256, 72, 1842, 345, 50256, 0],
            [5661, 318, 41927, 292, 12575, 50256, 0],
        ),
    )
    self.assertAllEqual(
        sample_weight,
        (
            [False, True, True, True, True, True, False],
            [True, True, True, True, True, True, False],
        ),
    )

    # Right padding: pad values are appended at the end of each row.
    x, y, sample_weight = preprocessor(prompts, padding_side="right")
    self.assertAllEqual(
        x["token_ids"],
        (
            [50256, 72, 1842, 345, 50256, 0, 0],
            [50256, 5661, 318, 41927, 292, 12575, 50256],
        ),
    )
    self.assertAllEqual(
        x["padding_mask"],
        ([1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1]),
    )
    self.assertAllEqual(
        y,
        (
            [72, 1842, 345, 50256, 0, 0, 0],
            [5661, 318, 41927, 292, 12575, 50256, 0],
        ),
    )
    self.assertAllEqual(
        sample_weight,
        (
            [True, True, True, True, False, False, False],
            [True, True, True, True, True, True, False],
        ),
    )

def test_padding_side_generate(self):
    """Verify `padding_side` is honored by `generate_preprocess`.

    Checks that left vs. right padding places pad tokens and mask
    zeros at the expected ends of the generated-input tensors.
    """
    preprocessor = CausalLMPreprocessor.from_preset(
        "gpt2_base_en", sequence_length=7
    )
    prompts = ["i love you", "this is keras hub"]

    # Left padding: shorter prompts are front-padded.
    result = preprocessor.generate_preprocess(
        prompts,
        padding_side="left",
        sequence_length=7,
    )
    self.assertAllEqual(
        result["token_ids"],
        (
            [0, 0, 50256, 72, 1842, 345, 0],
            [50256, 5661, 318, 41927, 292, 12575, 0],
        ),
    )
    self.assertAllEqual(
        result["padding_mask"],
        ([[0, 0, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0]]),
    )

    # Right padding: shorter prompts are back-padded.
    result = preprocessor.generate_preprocess(
        prompts,
        padding_side="right",
        sequence_length=7,
    )
    self.assertAllEqual(
        result["token_ids"],
        (
            [50256, 72, 1842, 345, 0, 0, 0],
            [50256, 5661, 318, 41927, 292, 12575, 0],
        ),
    )
    self.assertAllEqual(
        result["padding_mask"],
        ([[1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 0]]),
    )

@pytest.mark.large
def test_from_preset(self):
self.assertIsInstance(
Expand Down
Loading