From c323ab4d76310ff670b276515311d6f7e21e0c02 Mon Sep 17 00:00:00 2001 From: pass_lin <935499957@qq.com> Date: Fri, 25 Jul 2025 10:27:08 +0800 Subject: [PATCH 1/2] add left pad for CausalLMPreprocessor . --- .../preprocessing/multi_segment_packer.py | 7 +- .../layers/preprocessing/start_end_packer.py | 7 +- .../src/models/causal_lm_preprocessor.py | 88 +++++++++++++-- .../src/models/causal_lm_preprocessor_test.py | 102 ++++++++++++++++++ 4 files changed, 187 insertions(+), 17 deletions(-) diff --git a/keras_hub/src/layers/preprocessing/multi_segment_packer.py b/keras_hub/src/layers/preprocessing/multi_segment_packer.py index a4e0ba7ef4..d7b9ec8b65 100644 --- a/keras_hub/src/layers/preprocessing/multi_segment_packer.py +++ b/keras_hub/src/layers/preprocessing/multi_segment_packer.py @@ -281,9 +281,10 @@ def call( sequence_length=None, add_start_value=True, add_end_value=True, + padding_side=None, ): inputs, unbatched = self._sanitize_inputs(inputs) - + padding_side = padding_side or self.padding_side segments = self._trim_inputs(inputs) token_ids, segment_ids = self._combine_inputs( segments, @@ -296,13 +297,13 @@ def call( token_ids = pad( token_ids, shape=shape, - padding_side=self.padding_side, + padding_side=padding_side, pad_value=self.pad_value, ) segment_ids = pad( segment_ids, shape=shape, - padding_side=self.padding_side, + padding_side=padding_side, pad_value=0, ) # Remove the batch dim if added. diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py index efe10a4585..8efdfe6d7e 100644 --- a/keras_hub/src/layers/preprocessing/start_end_packer.py +++ b/keras_hub/src/layers/preprocessing/start_end_packer.py @@ -152,11 +152,12 @@ def call( sequence_length=None, add_start_value=True, add_end_value=True, + padding_side=None, ): inputs, unbatched, rectangular = convert_to_ragged_batch(inputs) x = inputs # Intermediate result. 
- batch_size = tf.shape(x)[0] + padding_side = padding_side or self.padding_side sequence_length = sequence_length or self.sequence_length dtype = inputs.dtype # Truncate. @@ -185,7 +186,7 @@ def call( outputs = pad( x, pad_value=self.pad_value, - padding_side=self.padding_side, + padding_side=padding_side, shape=(batch_size, sequence_length), ) outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs @@ -196,7 +197,7 @@ def call( mask = pad( mask, pad_value=False, - padding_side=self.padding_side, + padding_side=padding_side, shape=(batch_size, sequence_length), ) mask = tf.squeeze(mask, axis=0) if unbatched else mask diff --git a/keras_hub/src/models/causal_lm_preprocessor.py b/keras_hub/src/models/causal_lm_preprocessor.py index 3284e312cd..44246b2ff5 100644 --- a/keras_hub/src/models/causal_lm_preprocessor.py +++ b/keras_hub/src/models/causal_lm_preprocessor.py @@ -6,6 +6,11 @@ from keras_hub.src.utils.tensor_utils import preprocessing_function from keras_hub.src.utils.tensor_utils import strip_to_ragged +try: + from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice +except ImportError: + dynamic_update_slice = None + @keras_hub_export("keras_hub.models.CausalLMPreprocessor") class CausalLMPreprocessor(Preprocessor): @@ -64,6 +69,7 @@ def __init__( sequence_length=1024, add_start_token=True, add_end_token=True, + padding_side="right", **kwargs, ): super().__init__(**kwargs) @@ -72,6 +78,8 @@ def __init__( self.sequence_length = sequence_length self.add_start_token = add_start_token self.add_end_token = add_end_token + assert padding_side in ["right", "left"] + self.padding_side = padding_side def build(self, input_shape): # Defer packer creation to `build()` so that we can be sure tokenizer @@ -82,6 +90,7 @@ def build(self, input_shape): pad_value=self.tokenizer.pad_token_id, sequence_length=self.sequence_length, return_padding_mask=True, + padding_side=self.padding_side, ) self.built = True @@ -92,16 +101,44 @@ def call( y=None, 
sample_weight=None, sequence_length=None, + padding_side=None, ): - sequence_length = sequence_length or self.sequence_length x = self.tokenizer(x) - # Pad with one extra token to account for the truncation below. - token_ids, padding_mask = self.packer( - x, - sequence_length=sequence_length + 1, - add_start_value=self.add_start_token, - add_end_value=self.add_end_token, - ) + padding_side = padding_side or self.padding_side + sequence_length = sequence_length or self.sequence_length + if padding_side == "left": + addition_token_num = int(self.add_start_token + self.add_end_token) + token_ids, padding_mask = self.packer( + x, + sequence_length=x.to_tensor().shape[-1] + addition_token_num, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + padding_side=padding_side, + ) + token_ids, all_padding_mask = self.packer( + token_ids, + sequence_length=sequence_length + 1, + add_start_value=False, + add_end_value=False, + padding_side="right", + ) + if dynamic_update_slice is None: + raise ImportError( + "Left padding on CausalLMPreprocessor requires a TensorFlow " + "installation with XLA available." + ) + padding_mask = dynamic_update_slice( + all_padding_mask, padding_mask, [0] * len(padding_mask.shape) + ) + else: + # Pad with one extra token to account for the truncation below. + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length + 1, + add_start_value=self.add_start_token, + add_end_value=self.add_end_token, + padding_side=padding_side, + ) # The last token does not have a next token, so we truncate it out. x = { "token_ids": token_ids[..., :-1], @@ -116,6 +153,7 @@ def generate_preprocess( self, x, sequence_length=None, + padding_side=None, ): """Convert strings to integer token input for generation. 
@@ -130,11 +168,38 @@ def generate_preprocess( """ if not self.built: self.build(None) + padding_side = padding_side or self.padding_side x = self.tokenizer(x) - token_ids, padding_mask = self.packer( - x, sequence_length=sequence_length, add_end_value=False - ) + if padding_side == "left": + token_ids, padding_mask = self.packer( + x, + sequence_length=x.to_tensor().shape[-1] + 1, + add_end_value=False, + padding_side=padding_side, + ) + token_ids, all_padding_mask = self.packer( + token_ids, + sequence_length=sequence_length, + add_start_value=False, + add_end_value=False, + padding_side="right", + ) + if dynamic_update_slice is None: + raise ImportError( + "Left padding on CausalLMPreprocessor requires a TensorFlow " + "installation with XLA available." + ) + padding_mask = dynamic_update_slice( + all_padding_mask, padding_mask, [0] * len(padding_mask.shape) + ) + else: + token_ids, padding_mask = self.packer( + x, + sequence_length=sequence_length, + add_end_value=False, + padding_side=padding_side, + ) return { "token_ids": token_ids, "padding_mask": padding_mask, @@ -166,6 +231,7 @@ def get_config(self): "sequence_length": self.sequence_length, "add_start_token": self.add_start_token, "add_end_token": self.add_end_token, + "padding_side": self.padding_side, } ) return config diff --git a/keras_hub/src/models/causal_lm_preprocessor_test.py b/keras_hub/src/models/causal_lm_preprocessor_test.py index 8eb411a181..18600c983f 100644 --- a/keras_hub/src/models/causal_lm_preprocessor_test.py +++ b/keras_hub/src/models/causal_lm_preprocessor_test.py @@ -17,6 +17,108 @@ def test_preset_accessors(self): self.assertTrue(bert_presets.isdisjoint(all_presets)) self.assertTrue(gpt2_presets.issubset(all_presets)) + def test_padding_side(self): + preprocessor = CausalLMPreprocessor.from_preset( + "gpt2_base_en", sequence_length=7 + ) + # left pad + outputs = preprocessor( + ["i love you", "this is keras hub"], padding_side="left" + ) + + self.assertAllEqual( + 
outputs[0]["token_ids"], + ( + [0, 0, 50256, 72, 1842, 345, 50256], + [50256, 5661, 318, 41927, 292, 12575, 50256], + ), + ) + self.assertAllEqual( + outputs[0]["padding_mask"], + ([0, 0, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]), + ) + self.assertAllEqual( + outputs[1], + ( + [0, 50256, 72, 1842, 345, 50256, 0], + [5661, 318, 41927, 292, 12575, 50256, 0], + ), + ) + self.assertAllEqual( + outputs[2], + ( + [False, True, True, True, True, True, False], + [True, True, True, True, True, True, False], + ), + ) + # right pad + outputs = preprocessor( + ["i love you", "this is keras hub"], padding_side="right" + ) + self.assertAllEqual( + outputs[0]["token_ids"], + ( + [50256, 72, 1842, 345, 50256, 0, 0], + [50256, 5661, 318, 41927, 292, 12575, 50256], + ), + ) + self.assertAllEqual( + outputs[0]["padding_mask"], + ([1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1]), + ) + self.assertAllEqual( + outputs[1], + ( + [72, 1842, 345, 50256, 0, 0, 0], + [5661, 318, 41927, 292, 12575, 50256, 0], + ), + ) + self.assertAllEqual( + outputs[2], + ( + [True, True, True, True, False, False, False], + [True, True, True, True, True, True, False], + ), + ) + + def test_padding_side_generate(self): + preprocessor = CausalLMPreprocessor.from_preset( + "gpt2_base_en", sequence_length=7 + ) + # left pad + outputs = preprocessor.generate_preprocess( + ["i love you", "this is keras hub"], + padding_side="left", + sequence_length=7, + ) + self.assertAllEqual( + outputs["token_ids"], + ( + [0, 0, 50256, 72, 1842, 345, 0], + [50256, 5661, 318, 41927, 292, 12575, 0], + ), + ) + self.assertAllEqual( + outputs["padding_mask"], + ([[0, 0, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1, 0]]), + ) + outputs = preprocessor.generate_preprocess( + ["i love you", "this is keras hub"], + padding_side="right", + sequence_length=7, + ) + self.assertAllEqual( + outputs["token_ids"], + ( + [50256, 72, 1842, 345, 0, 0, 0], + [50256, 5661, 318, 41927, 292, 12575, 0], + ), + ) + self.assertAllEqual( + outputs["padding_mask"], + 
+            ([[1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 0]]),
+        )
+
     @pytest.mark.large
     def test_from_preset(self):
         self.assertIsInstance(

From bf135d5be6d2f00cae6c376dbf12ec091664e0fd Mon Sep 17 00:00:00 2001
From: pass_lin <935499957@qq.com>
Date: Fri, 25 Jul 2025 10:33:31 +0800
Subject: [PATCH 2/2] format

---
 keras_hub/src/models/causal_lm_preprocessor.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/keras_hub/src/models/causal_lm_preprocessor.py b/keras_hub/src/models/causal_lm_preprocessor.py
index 44246b2ff5..46e2ff2b47 100644
--- a/keras_hub/src/models/causal_lm_preprocessor.py
+++ b/keras_hub/src/models/causal_lm_preprocessor.py
@@ -124,8 +124,7 @@ def call(
             )
             if dynamic_update_slice is None:
                 raise ImportError(
-                    "Left padding on CausalLMPreprocessor requires a TensorFlow "
-                    "installation with XLA available."
+                    "Left padding requires TensorFlow with XLA support."
                 )
             padding_mask = dynamic_update_slice(
                 all_padding_mask, padding_mask, [0] * len(padding_mask.shape)
@@ -187,8 +186,7 @@ def generate_preprocess(
             )
             if dynamic_update_slice is None:
                 raise ImportError(
-                    "Left padding on CausalLMPreprocessor requires a TensorFlow "
-                    "installation with XLA available."
+                    "Left padding requires TensorFlow with XLA support."
                 )
             padding_mask = dynamic_update_slice(
                 all_padding_mask, padding_mask, [0] * len(padding_mask.shape)