Commit 0efbb6e

fix GPT2 tokenizer's special_tokens_mask when used with add_bos_token=True (#19036)
1 parent: 0e24548 · commit: 0efbb6e

File tree

2 files changed: +57 −0 lines changed

src/transformers/models/gpt2/tokenization_gpt2.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -261,6 +261,38 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
 
         return output + bos_token_ids + token_ids_1
 
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        """
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        if not self.add_bos_token:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=False
+            )
+
+        if token_ids_1 is None:
+            return [1] + ([0] * len(token_ids_0))
+        return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1))
+
     def _tokenize(self, text):
         """Tokenize a string."""
         bpe_tokens = []
```
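
Before this commit, the slow tokenizer fell back to the base-class `get_special_tokens_mask`, so with `add_bos_token=True` the mask did not account for the BOS token that `build_inputs_with_special_tokens` prepends. A minimal usage sketch of the fixed behaviour follows (not part of the commit; it assumes the public `gpt2` checkpoint and the slow `GPT2Tokenizer`):

```python
from transformers import GPT2Tokenizer

# Sketch only: assumes the "gpt2" checkpoint can be downloaded or is cached locally.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", add_bos_token=True)

# Single sequence: the prepended BOS token (<|endoftext|>) is now marked with a 1.
enc = tokenizer("Encode this.", return_special_tokens_mask=True)
print(enc["input_ids"][0] == tokenizer.bos_token_id)  # True
print(enc["special_tokens_mask"])                     # [1, 0, 0, ...], same length as input_ids

# Sequence pair: a BOS token is inserted before each segment, so the mask has two 1s.
enc_pair = tokenizer("Encode this.", "This one too please.", return_special_tokens_mask=True)
print(enc_pair["special_tokens_mask"])                # [1, 0, ..., 0, 1, 0, ..., 0]
```

The mask now has the same length as `input_ids`, which is what callers of `return_special_tokens_mask=True` (and the new test below) rely on.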

tests/models/gpt2/test_tokenization_gpt2.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -250,3 +250,28 @@ def test_add_bos_token_slow(self):
     # tokenizer has no padding token
     def test_padding_different_model_input_name(self):
         pass
+
+    def test_special_tokens_mask_input_pairs_and_bos_token(self):
+        # TODO: change to self.get_tokenizers() when the fast version is implemented
+        tokenizers = [self.get_tokenizer(do_lower_case=False, add_bos_token=True)]
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                sequence_0 = "Encode this."
+                sequence_1 = "This one too please."
+                encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
+                encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
+                encoded_sequence_dict = tokenizer.encode_plus(
+                    sequence_0,
+                    sequence_1,
+                    add_special_tokens=True,
+                    return_special_tokens_mask=True,
+                )
+                encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
+                special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
+                self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
+
+                filtered_sequence = [
+                    (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
+                ]
+                filtered_sequence = [x for x in filtered_sequence if x is not None]
+                self.assertEqual(encoded_sequence, filtered_sequence)
```
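
The new test mirrors the generic special-tokens-mask pair checks from the shared tokenizer test suite, but builds the slow tokenizer with `add_bos_token=True` so the prepended BOS tokens are exercised: the mask must have the same length as `input_ids`, and dropping the positions marked `1` must recover the raw encoded pair. It can be run in isolation with, for example, `pytest tests/models/gpt2/test_tokenization_gpt2.py -k special_tokens_mask_input_pairs_and_bos_token`.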
