Implement left padding #2242
Changes to the `StartEndPacker` layer:

@@ -39,6 +39,8 @@ class StartEndPacker(PreprocessingLayer):
            0 or "" will be added depending on the dtype of the input tensor.
        return_padding_mask: bool. Whether to return a boolean padding mask of
            all locations that are filled in with the `pad_value`.
+       padding_side: str. Whether to pad the input on the "left" or "right".
+           Defaults to "right".

    Call arguments:
        inputs: A `tf.Tensor`, `tf.RaggedTensor`, or list of python strings.
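For context, a minimal usage sketch of the new argument, mirroring the cases exercised in the tests below (the import path is assumed; `0` is the default pad value for integer inputs):

```python
from keras_hub.layers import StartEndPacker  # import path assumed

# Default behaviour: pad on the right.
packer = StartEndPacker(sequence_length=5)
packer([5, 6, 7])  # -> [5, 6, 7, 0, 0]

# New behaviour: pad on the left.
packer = StartEndPacker(sequence_length=5, padding_side="left")
packer([5, 6, 7])  # -> [0, 0, 5, 6, 7]
```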
@@ -111,6 +113,7 @@ def __init__(
        pad_value=None,
        return_padding_mask=False,
        name=None,
+       padding_side="right",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)

@@ -139,6 +142,20 @@ def check_special_value_type(value, value_name):
        self.pad_value = pad_value
        self.return_padding_mask = return_padding_mask
+       self.padding_side = padding_side

+   def pad(self, x, shape, pad_value):
+       if self.padding_side == "left":
+           x = x[..., ::-1]
+       outputs = x.to_tensor(
+           default_value=pad_value,
+           shape=shape,
+       )
+       if self.padding_side == "left":
+           outputs = outputs[..., ::-1]
+       return outputs
+
    @preprocessing_function
    def call(
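The new `pad` helper implements left padding with a reverse-pad-reverse trick: `RaggedTensor.to_tensor` only pads on the right, so each row is reversed, right-padded, and reversed back. A standalone sketch of the same idea in plain TensorFlow (an illustration of the trick, not the layer's own code path):

```python
import tensorflow as tf

x = tf.ragged.constant([[5, 6, 7], [8, 9, 10, 11]])

# Reverse each row, right-pad to a fixed width, then reverse back.
flipped = x[..., ::-1]
padded = flipped.to_tensor(default_value=0, shape=(2, 5))
left_padded = padded[..., ::-1]
# left_padded -> [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]]
```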
@@ -154,6 +171,13 @@ def call(
        batch_size = tf.shape(x)[0]
        sequence_length = sequence_length or self.sequence_length
        dtype = inputs.dtype
+       # Truncate.
+       truncation_length = sequence_length
+       if add_start_value and self.start_value is not None:
+           truncation_length -= len(self.start_value)
+       if add_end_value and self.end_value is not None:
+           truncation_length -= len(self.end_value)
+       x = x[..., :truncation_length]

        # Concatenate start and end tokens.
        if add_start_value and self.start_value is not None:
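Moving truncation ahead of the start/end concatenation keeps the budget explicit: with `sequence_length=7`, `start_value=98`, and `end_value=99`, the truncation length is 7 - 1 - 1 = 5, so `range(10)` is cut to `[0, 1, 2, 3, 4]` before the special values are attached, giving `[98, 0, 1, 2, 3, 4, 99]` for either padding side (see `test_truncation_side_flips` below).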
@@ -167,23 +191,26 @@
            end_token_id_tensor = tf.repeat(
                end_value[tf.newaxis, :], repeats=batch_size, axis=0
            )
-           # Trim to leave room for end token.
-           x = x[..., : sequence_length - len(self.end_value)]
            x = tf.concat([x, end_token_id_tensor], axis=-1)

        # Pad to desired length.
-       outputs = x.to_tensor(
-           default_value=self.pad_value,
+       outputs = self.pad(
+           x,
            shape=(batch_size, sequence_length),
+           pad_value=self.pad_value,
        )
        outputs = tf.squeeze(outputs, axis=0) if unbatched else outputs

        if self.return_padding_mask:
            mask = tf.ones_like(x, dtype="bool")
-           mask = mask.to_tensor(shape=(batch_size, sequence_length))
+           mask = self.pad(
+               mask,
+               shape=(batch_size, sequence_length),
+               pad_value=False,
+           )
            mask = tf.squeeze(mask, axis=0) if unbatched else mask
            return outputs, mask

        return outputs

    def get_config(self):
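Because the padding mask is built from the same ragged values and routed through the same `pad` helper, it lines up with the chosen side automatically. A usage sketch based on the new padding-mask tests below (import as in the earlier sketch):

```python
packer = StartEndPacker(
    sequence_length=6,
    start_value=1,
    end_value=2,
    return_padding_mask=True,
    padding_side="left",
)
output, mask = packer([[5, 6, 7], [8, 9, 10, 11]])
# output -> [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
# mask   -> [[False, True, True, True, True, True],
#            [True,  True, True, True, True, True]]
```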
@@ -195,6 +222,7 @@ def get_config(self):
                "end_value": self._end_value,
                "pad_value": self.pad_value,
                "return_padding_mask": self.return_padding_mask,
+               "padding_side": self.padding_side,
            }
        )
        return config
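With `padding_side` included in `get_config`, the setting should survive serialization. A quick round-trip sketch, assuming the stock `from_config` inherited from the base Keras layer:

```python
packer = StartEndPacker(sequence_length=5, padding_side="left")
config = packer.get_config()
assert config["padding_side"] == "left"

restored = StartEndPacker.from_config(config)
assert restored.padding_side == "left"
```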
Changes to `StartEndPackerTest`:
@@ -6,11 +6,19 @@
class StartEndPackerTest(TestCase):
    def test_dense_input(self):
+       # right padding
        input_data = [5, 6, 7]
        start_end_packer = StartEndPacker(sequence_length=5)
        output = start_end_packer(input_data)
        expected_output = [5, 6, 7, 0, 0]
        self.assertAllEqual(output, expected_output)
+       # left padding
+       start_end_packer = StartEndPacker(
+           sequence_length=5, padding_side="left"
+       )
+       output = start_end_packer(input_data)
+       expected_output = [0, 0, 5, 6, 7]
+       self.assertAllEqual(output, expected_output)

    def test_bfloat16_dtype(self):
        # Core Keras has a strange bug where it converts int to floats in
@@ -21,29 +29,54 @@ def test_bfloat16_dtype(self):
        self.assertDTypeEqual(output, "int32")

    def test_dense_2D_input(self):
+       # right padding
        input_data = [[5, 6, 7]]
        start_end_packer = StartEndPacker(sequence_length=5)
        output = start_end_packer(input_data)
        expected_output = [[5, 6, 7, 0, 0]]
        self.assertAllEqual(output, expected_output)
+       # left padding
+       start_end_packer = StartEndPacker(
+           sequence_length=5, padding_side="left"
+       )
+       output = start_end_packer(input_data)
+       expected_output = [[0, 0, 5, 6, 7]]
+       self.assertAllEqual(output, expected_output)

    def test_ragged_input(self):
+       # right padding
        input_data = [[5, 6, 7], [8, 9, 10, 11]]
        start_end_packer = StartEndPacker(sequence_length=5)
        output = start_end_packer(input_data)
        expected_output = [[5, 6, 7, 0, 0], [8, 9, 10, 11, 0]]
        self.assertAllEqual(output, expected_output)
+       # left padding
+       start_end_packer = StartEndPacker(
+           sequence_length=5, padding_side="left"
+       )
+       output = start_end_packer(input_data)
+       expected_output = [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]]
+       self.assertAllEqual(output, expected_output)

    def test_start_end_token(self):
+       # right padding
        input_data = [[5, 6, 7], [8, 9, 10, 11]]
        start_end_packer = StartEndPacker(
            sequence_length=6, start_value=1, end_value=2
        )
        output = start_end_packer(input_data)
        expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
        self.assertAllEqual(output, expected_output)
+       # left padding
+       start_end_packer = StartEndPacker(
+           sequence_length=6, start_value=1, end_value=2, padding_side="left"
+       )
+       output = start_end_packer(input_data)
+       expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
+       self.assertAllEqual(output, expected_output)

    def test_multiple_start_end_tokens(self):
+       # right padding
        input_data = [[5, 6, 7], [8, 9, 10, 11, 12, 13]]
        start_end_packer = StartEndPacker(
            sequence_length=8,
@@ -55,7 +88,20 @@ def test_multiple_start_end_tokens(self):
        expected_output = [[1, 2, 5, 6, 7, 3, 4, 0], [1, 2, 8, 9, 10, 11, 3, 4]]
        self.assertAllEqual(output, expected_output)

+       # left padding
+       start_end_packer = StartEndPacker(
+           sequence_length=8,
+           start_value=[1, 2],
+           end_value=[3, 4],
+           pad_value=0,
+           padding_side="left",
+       )
+       output = start_end_packer(input_data)
+       expected_output = [[0, 1, 2, 5, 6, 7, 3, 4], [1, 2, 8, 9, 10, 11, 3, 4]]
+       self.assertAllEqual(output, expected_output)
+
    def test_start_end_padding_value(self):
+       # right padding
        input_data = [[5, 6, 7], [8, 9, 10, 11]]
        start_end_packer = StartEndPacker(
            sequence_length=7, start_value=1, end_value=2, pad_value=3
@@ -64,7 +110,58 @@ def test_start_end_padding_value(self):
        expected_output = [[1, 5, 6, 7, 2, 3, 3], [1, 8, 9, 10, 11, 2, 3]]
        self.assertAllEqual(output, expected_output)

+       # left padding
+       start_end_packer = StartEndPacker(
+           sequence_length=7,
+           start_value=1,
+           end_value=2,
+           pad_value=3,
+           padding_side="left",
+       )
+       output = start_end_packer(input_data)
+       expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]]
+       self.assertAllEqual(output, expected_output)
+
+   def test_truncation_side_flips(self):
+       # right padding
+       input_data = list(range(10))
+       packer = StartEndPacker(
+           sequence_length=7,
+           start_value=98,
+           end_value=99,
+       )
+       expected_output = [98, 0, 1, 2, 3, 4, 99]
+       self.assertAllEqual(packer(input_data), expected_output)
+
+       # left padding
+       packer = StartEndPacker(
+           sequence_length=7,
+           start_value=98,
+           end_value=99,
+           padding_side="left",
+       )
+       self.assertAllEqual(packer(input_data), expected_output)
+
+   def test_truncation_side_flips_wo_endvalue(self):
+       # right padding
+       input_data = list(range(10))
+       packer = StartEndPacker(
+           sequence_length=7,
+           start_value=98,
+       )
+       expected_output = [98, 0, 1, 2, 3, 4, 5]
+       self.assertAllEqual(packer(input_data), expected_output)
+
+       # left padding
+       packer = StartEndPacker(
+           sequence_length=7,
+           start_value=98,
+           padding_side="left",
+       )
+       self.assertAllEqual(packer(input_data), expected_output)
+
    def test_end_token_value_during_truncation(self):
+       # right padding
        input_data = [[5, 6], [8, 9, 10, 11, 12, 13]]
        start_end_packer = StartEndPacker(
            sequence_length=5, start_value=1, end_value=2, pad_value=0
@@ -73,7 +170,20 @@ def test_end_token_value_during_truncation(self):
        expected_output = [[1, 5, 6, 2, 0], [1, 8, 9, 10, 2]]
        self.assertAllEqual(output, expected_output)

+       # left padding
+       start_end_packer = StartEndPacker(
+           sequence_length=5,
+           start_value=1,
+           end_value=2,
+           pad_value=0,
+           padding_side="left",
+       )
+       output = start_end_packer(input_data)
+       expected_output = [[0, 1, 5, 6, 2], [1, 8, 9, 10, 2]]
+       self.assertAllEqual(output, expected_output)
+
    def test_string_input(self):
+       # right padding
        input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
        start_end_packer = StartEndPacker(
            sequence_length=5,
@@ -88,7 +198,23 @@ def test_string_input(self):
        ]
        self.assertAllEqual(output, expected_output)

+       # left padding
+       start_end_packer = StartEndPacker(
+           sequence_length=5,
+           start_value="[START]",
+           end_value="[END]",
+           pad_value="[PAD]",
+           padding_side="left",
+       )
+       output = start_end_packer(input_data)
+       expected_output = [
+           ["[START]", "KerasHub", "is", "awesome", "[END]"],
+           ["[PAD]", "[PAD]", "[START]", "amazing", "[END]"],
+       ]
+       self.assertAllEqual(output, expected_output)
+
    def test_string_input_with_multiple_special_values(self):
+       # right padding
        input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
        start_end_packer = StartEndPacker(
            sequence_length=6,
@@ -103,6 +229,21 @@ def test_string_input_with_multiple_special_values(self):
        ]
        self.assertAllEqual(output, expected_output)

+       # left padding
+       start_end_packer = StartEndPacker(
+           sequence_length=6,
+           start_value=["[END]", "[START]"],
+           end_value="[END]",
+           pad_value="[PAD]",
+           padding_side="left",
+       )
+       output = start_end_packer(input_data)
+       expected_output = [
+           ["[END]", "[START]", "KerasHub", "is", "awesome", "[END]"],
+           ["[PAD]", "[PAD]", "[END]", "[START]", "amazing", "[END]"],
+       ]
+       self.assertAllEqual(output, expected_output)

    def test_special_token_dtype_error(self):
        with self.assertRaises(ValueError):
            StartEndPacker(sequence_length=5, start_value=1.0)
@@ -147,3 +288,39 @@ def test_get_config(self):
        }

        self.assertEqual(config, {**config, **expected_config_subset})

+   def test_return_padding_mask_right_padding(self):
+       input_data = [[5, 6, 7], [8, 9, 10, 11]]
+       start_end_packer = StartEndPacker(
+           sequence_length=6,
+           start_value=1,
+           end_value=2,
+           return_padding_mask=True,
+       )
+       output, padding_mask = start_end_packer(input_data)
+       expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
+       expected_padding_mask = [
+           [True, True, True, True, True, False],
+           [True, True, True, True, True, True],
+       ]
+       self.assertAllEqual(output, expected_output)
+       self.assertAllEqual(padding_mask, expected_padding_mask)
+
+   def test_return_padding_mask_left_padding(self):
+       input_data = [[5, 6, 7], [8, 9, 10, 11]]
+       start_end_packer = StartEndPacker(
+           sequence_length=6,
+           start_value=1,
+           end_value=2,
+           return_padding_mask=True,
+           padding_side="left",
+       )
+       output, padding_mask = start_end_packer(input_data)
+       expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
+       expected_padding_mask = [
+           [False, True, True, True, True, True],
+           [True, True, True, True, True, True],
+       ]
+       self.assertAllEqual(output, expected_output)
+       self.assertAllEqual(padding_mask, expected_padding_mask)