@@ -59,10 +59,10 @@
 SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
 
 
-def get_unsplittable_tokens_pattern(unsplittable_tokens):
-    if unsplittable_tokens is None or len(unsplittable_tokens) == 0:
+def get_special_tokens_pattern(special_tokens):
+    if special_tokens is None or len(special_tokens) == 0:
         return None
-    return r"|".join([re.escape(token) for token in unsplittable_tokens])
+    return r"|".join([re.escape(token) for token in special_tokens])
 
 
 def bytes_to_unicode():
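For reference, the renamed helper only builds an RE2-compatible alternation out of the escaped tokens. A minimal standalone sketch of its behavior (the token strings below are hypothetical and not taken from this PR):

```python
import re


def get_special_tokens_pattern(special_tokens):
    if special_tokens is None or len(special_tokens) == 0:
        return None
    return r"|".join([re.escape(token) for token in special_tokens])


# Regex metacharacters such as "|" and "[" are escaped, so each token is
# matched literally in the resulting alternation.
print(get_special_tokens_pattern(["<|endoftext|>", "[PAD]"]))
# <\|endoftext\|>|\[PAD\]
print(get_special_tokens_pattern([]))
# None
```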
@@ -97,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
+def split_strings_for_bpe(inputs, special_tokens_pattern=None):
     # We need to recreate the exact behavior of token presplitting in the
     # original gpt2 tokenizer which uses a lookahead. As re2 does not
     # support lookahead match, we are using an alternative insert a special
@@ -110,23 +110,23 @@ def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
         inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
     )
 
-    if unsplittable_tokens_pattern is not None:
-        # First split the unsplittable tokens from the input.
+    if special_tokens_pattern is not None:
+        # First split the special tokens from the input.
         raw_tokens = tf_text.regex_split(
-            inputs, unsplittable_tokens_pattern, unsplittable_tokens_pattern
+            inputs, special_tokens_pattern, special_tokens_pattern
         )
-        # Then split using both `unsplittable_tokens_pattern` and
+        # Then split using both `special_tokens_pattern` and
         # `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
-        # affecting the unsplittable tokens.
-        # We split unsplittable tokens first then apply this split instead of
+        # affecting the special tokens.
+        # We split special tokens first then apply this split instead of
         # applying this split directly, because otherwise we will not split
-        # unsplittable tokens from inputs properly, because of this pattern
+        # special tokens from inputs properly, because of this pattern
         # ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
         # e.g., [" </s>"] will be [" </", "s", ">"] instead of [" ", "</s>"]
         raw_tokens = tf_text.regex_split(
             raw_tokens,
-            r"|".join([unsplittable_tokens_pattern, SPLIT_PATTERN_1]),
-            r"|".join([unsplittable_tokens_pattern, SPLIT_PATTERN_1]),
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
         )
         raw_tokens = raw_tokens.merge_dims(-2, -1)
     else:
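To make the two-stage split above concrete, here is a hedged sketch of the first stage only, assuming `tensorflow-text` is installed; `"</s>"` is a hypothetical special token, and `SPLIT_PATTERN_1` (defined earlier in the file) is deliberately left out:

```python
import re

import tensorflow as tf
import tensorflow_text as tf_text

special_tokens_pattern = re.escape("</s>")  # as get_special_tokens_pattern(["</s>"]) would produce
raw_tokens = tf_text.regex_split(
    tf.constant(["a quick fox </s>"]),
    special_tokens_pattern,  # pattern to split on
    special_tokens_pattern,  # pattern whose matches are kept as tokens
)
print(raw_tokens.to_list())
# roughly: [[b'a quick fox ', b'</s>']]
```

Splitting the special tokens out first is what keeps `" </s>"` intact; applying `SPLIT_PATTERN_1` alone would shear it into `[" </", "s", ">"]`, as the comment notes.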
@@ -238,16 +238,16 @@ class BytePairTokenizer(tokenizer.Tokenizer):
             a prefix space to the first word will cause it to be tokenized
             equivalently to all subsequent words in the sequence.
             Defaults to `False`.
-        unsplittable_tokens: list. A list of unsplittable tokens. when
-            `unsplittable_tokens_in_strings` is set to `True`, unsplittable
+        special_tokens: list. A list of special tokens. When
+            `special_tokens_in_strings` is set to `True`, special
             tokens will never be split during the word-level splitting applied
             before the byte-pair encoding. This can be used to ensure special
             tokens map to unique indices in the vocabulary, even if these
-            unsplittable tokens contain splittable characters such as
-            punctuation. Unsplittable tokens must still be included in
+            special tokens contain splittable characters such as
+            punctuation. Special tokens must still be included in
             `vocabulary`. Defaults to `None`.
-        unsplittable_tokens_in_strings: bool. To indicate if the tokenizer
-            should expect unsplittable tokens in input strings that should be
+        special_tokens_in_strings: bool. To indicate if the tokenizer
+            should expect special tokens in input strings that should be
             tokenized and mapped correctly to their ids. Defaults to False.
 
     Examples:
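For readers of this diff, a hedged sketch of how the renamed arguments are meant to be used, assuming the public `keras_nlp.tokenizers.BytePairTokenizer` export and a toy vocabulary and merge list invented for illustration:

```python
import keras_nlp

# The special token must itself be a vocabulary entry, as the check in
# set_vocabulary_and_merges enforces.
vocab = {"<|endoftext|>": 0, "butter": 1, "fly": 2}
merges = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]

tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
    vocabulary=vocab,
    merges=merges,
    special_tokens=["<|endoftext|>"],
    special_tokens_in_strings=True,
)
# With `special_tokens_in_strings=True`, "<|endoftext|>" inside the raw string
# is kept whole and mapped to id 0 rather than being split at its punctuation.
tokenizer(["butterfly<|endoftext|>"])
# expected to come out roughly as [[1, 2, 0]]
```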
@@ -287,8 +287,8 @@ def __init__(
         merges=None,
         sequence_length=None,
         add_prefix_space=False,
-        unsplittable_tokens=None,
-        unsplittable_tokens_in_strings=False,
+        special_tokens=None,
+        special_tokens_in_strings=False,
         dtype="int32",
         **kwargs,
     ) -> None:
@@ -303,11 +303,11 @@ def __init__(
         super().__init__(dtype=dtype, **kwargs)
         self.sequence_length = sequence_length
         self.add_prefix_space = add_prefix_space
-        self.unsplittable_tokens = unsplittable_tokens
-        self._unsplittable_tokens_pattern = None
-        if unsplittable_tokens_in_strings:
-            self._unsplittable_tokens_pattern = get_unsplittable_tokens_pattern(
-                unsplittable_tokens
+        self.special_tokens = special_tokens
+        self._special_tokens_pattern = None
+        if special_tokens_in_strings:
+            self._special_tokens_pattern = get_special_tokens_pattern(
+                special_tokens
             )
 
         # Create byte <=> unicode mapping. This is useful for handling
@@ -362,8 +362,8 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
             )
 
         # Check for special tokens in vocabulary.
-        if self.unsplittable_tokens is not None:
-            for token in self.unsplittable_tokens:
+        if self.special_tokens is not None:
+            for token in self.special_tokens:
                 if token not in self.get_vocabulary():
                     raise ValueError(
                         f"Cannot find token `'{token}'` in the provided "
@@ -383,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
             )
 
         self.cache = BytePairTokenizerCache()
-        if self.unsplittable_tokens:
+        if self.special_tokens and self._special_tokens_pattern is not None:
             # Put special tokens into cache, so it won't be further split and
             # merged.
-            self.cache.insert(
-                self.unsplittable_tokens, self.unsplittable_tokens
-            )
+            self.cache.insert(self.special_tokens, self.special_tokens)
 
         # Create mapping between string tokens to int ids, and vice versa.
         byte_pairs = [x[0] for x in self.vocabulary.items()]
@@ -566,9 +564,7 @@ def tokenize(self, inputs):
         if scalar_input:
             inputs = tf.expand_dims(inputs, 0)
 
-        raw_tokens = split_strings_for_bpe(
-            inputs, self._unsplittable_tokens_pattern
-        )
+        raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern)
         token_row_splits = raw_tokens.row_splits
         flat_tokens = raw_tokens.flat_values
 
@@ -662,7 +658,7 @@ def get_config(self):
             {
                 "sequence_length": self.sequence_length,
                 "add_prefix_space": self.add_prefix_space,
-                "unsplittable_tokens": self.unsplittable_tokens,
+                "special_tokens": self.special_tokens,
             }
         )
         return config
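Since `get_config()` now serializes the list under the renamed key, a quick config check is the easiest way to see the change; a hedged sketch with the same toy vocabulary as above. Note that `special_tokens_in_strings` is not part of the config in this diff, so it would have to be passed again when re-instantiating the tokenizer.

```python
import keras_nlp

tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
    vocabulary={"<|endoftext|>": 0, "butter": 1, "fly": 2},
    merges=["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"],
    special_tokens=["<|endoftext|>"],
)
print(tokenizer.get_config()["special_tokens"])
# ['<|endoftext|>']
```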