 SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""


-def get_unsplittable_tokens_pattern(unsplittable_tokens):
-    if unsplittable_tokens is None or len(unsplittable_tokens) == 0:
+def get_special_tokens_pattern(special_tokens):
+    if special_tokens is None or len(special_tokens) == 0:
         return None
-    return r"|".join([re.escape(token) for token in unsplittable_tokens])
+    return r"|".join([re.escape(token) for token in special_tokens])


 def bytes_to_unicode():
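For context, a quick standalone sketch (not part of the diff) of what the renamed helper returns: the escaped special tokens OR-joined into a single alternation that re2 can use. The example tokens below are illustrative assumptions.

```python
import re

def get_special_tokens_pattern(special_tokens):
    # Same logic as the renamed helper above: escape each token and
    # join them into one alternation pattern.
    if special_tokens is None or len(special_tokens) == 0:
        return None
    return r"|".join([re.escape(token) for token in special_tokens])

print(get_special_tokens_pattern(["<|endoftext|>", "</s>"]))
# <\|endoftext\|>|</s>
```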
@@ -97,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result


-def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
+def split_strings_for_bpe(inputs, special_tokens_pattern=None):
     # We need to recreate the exact behavior of token presplitting in the
     # original gpt2 tokenizer which uses a lookahead. As re2 does not
     # support lookahead match, we are using an alternative insert a special
@@ -110,26 +110,23 @@ def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
         inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
     )

-    if unsplittable_tokens_pattern is not None:
-        # First split the unsplittable tokens from the input.
+    if special_tokens_pattern is not None:
+        # First split the special tokens from the input.
         raw_tokens = tf_text.regex_split(
-            inputs, unsplittable_tokens_pattern, unsplittable_tokens_pattern
+            inputs, special_tokens_pattern, special_tokens_pattern
         )
-        split_pattern_1_with_unsplittable_tokens = r"|".join(
-            [unsplittable_tokens_pattern, SPLIT_PATTERN_1]
-        )
-        # Then split using both `unsplittable_tokens_pattern` and
+        # Then split using both `special_tokens_pattern` and
         # `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
-        # affecting the unsplittable tokens.
-        # We split unsplittable tokens first then apply this split instead of
+        # affecting the special tokens.
+        # We split special tokens first then apply this split instead of
         # applying this split directly, because otherwise we will not split
-        # unsplittable tokens from inputs properly, because of this pattern
+        # special tokens from inputs properly, because of this pattern
         # ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
         # e.g., [" </s>"] will be [" </", "s", ">"] instead of [" ", "</s>"]
         raw_tokens = tf_text.regex_split(
             raw_tokens,
-            split_pattern_1_with_unsplittable_tokens,
-            split_pattern_1_with_unsplittable_tokens,
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
         )
         raw_tokens = raw_tokens.merge_dims(-2, -1)
     else:
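To make the reasoning in the comments above concrete, here is a minimal sketch (not part of the diff). It uses a simplified stand-in for `SPLIT_PATTERN_1` and skips the "६" lookahead workaround; the toy pattern and input string are assumptions.

```python
import re

import tensorflow as tf
import tensorflow_text as tf_text

special_tokens_pattern = re.escape("</s>")
# Simplified stand-in for SPLIT_PATTERN_1: words, digits, or punctuation runs,
# each optionally preceded by a single space.
word_pattern = r" ?[a-zA-Z]+| ?[0-9]+| ?[^\sa-zA-Z0-9]+"

inputs = tf.constant(["hello </s>"])

# Word-level splitting alone breaks the special token apart.
naive = tf_text.regex_split(inputs, word_pattern, word_pattern)
print(naive)  # tokens: [b'hello', b' </', b's', b'>']

# Splitting special tokens out first, then splitting on both patterns,
# keeps the special token intact.
raw_tokens = tf_text.regex_split(
    inputs, special_tokens_pattern, special_tokens_pattern
)
raw_tokens = tf_text.regex_split(
    raw_tokens,
    r"|".join([special_tokens_pattern, word_pattern]),
    r"|".join([special_tokens_pattern, word_pattern]),
)
print(raw_tokens.merge_dims(-2, -1))  # tokens: [b'hello', b' ', b'</s>']
```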
@@ -241,12 +238,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
             a prefix space to the first word will cause it to be tokenized
             equivalently to all subsequent words in the sequence.
             Defaults to `False`.
-        unsplittable_tokens: list. A list of strings that will
-            never be split during the word-level splitting applied before the
-            byte-pair encoding. This can be used to ensure special tokens map to
-            unique indices in the vocabulary, even if these special tokens
-            contain splittable characters such as punctuation. Special tokens
-            must still be included in `vocabulary`. Defaults to `None`.
+        special_tokens: list. A list of special tokens. When
+            `special_tokens_in_strings` is set to `True`, special
+            tokens will never be split during the word-level splitting applied
+            before the byte-pair encoding. This can be used to ensure special
+            tokens map to unique indices in the vocabulary, even if these
+            special tokens contain splittable characters such as
+            punctuation. Special tokens must still be included in
+            `vocabulary`. Defaults to `None`.
+        special_tokens_in_strings: bool. To indicate if the tokenizer
+            should expect special tokens in input strings that should be
+            tokenized and mapped correctly to their ids. Defaults to `False`.

     Examples:

@@ -285,7 +287,8 @@ def __init__(
         merges=None,
         sequence_length=None,
         add_prefix_space=False,
-        unsplittable_tokens=None,
+        special_tokens=None,
+        special_tokens_in_strings=False,
         dtype="int32",
         **kwargs,
     ) -> None:
@@ -300,10 +303,12 @@ def __init__(
         super().__init__(dtype=dtype, **kwargs)
         self.sequence_length = sequence_length
         self.add_prefix_space = add_prefix_space
-        self.unsplittable_tokens = unsplittable_tokens
-        self._unsplittable_tokens_pattern = get_unsplittable_tokens_pattern(
-            unsplittable_tokens
-        )
+        self.special_tokens = special_tokens
+        self._special_tokens_pattern = None
+        if special_tokens_in_strings:
+            self._special_tokens_pattern = get_special_tokens_pattern(
+                special_tokens
+            )

         # Create byte <=> unicode mapping. This is useful for handling
         # whitespace tokens.
@@ -355,6 +360,17 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
                 "token to int ids. Received: "
                 f"`type(vocabulary)={type(vocabulary)}`."
             )
+
+        # Check for special tokens in vocabulary.
+        if self.special_tokens is not None:
+            for token in self.special_tokens:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+
         if isinstance(merges, str):
             with open(merges, encoding="utf-8") as f:
                 self.merges = [bp.rstrip() for bp in f]
@@ -367,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
             )

         self.cache = BytePairTokenizerCache()
-        if self.unsplittable_tokens:
+        if self.special_tokens and self._special_tokens_pattern is not None:
             # Put special tokens into cache, so it won't be further split and
             # merged.
-            self.cache.insert(
-                self.unsplittable_tokens, self.unsplittable_tokens
-            )
+            self.cache.insert(self.special_tokens, self.special_tokens)

         # Create mapping between string tokens to int ids, and vice versa.
         byte_pairs = [x[0] for x in self.vocabulary.items()]
@@ -550,9 +564,7 @@ def tokenize(self, inputs):
         if scalar_input:
             inputs = tf.expand_dims(inputs, 0)

-        raw_tokens = split_strings_for_bpe(
-            inputs, self._unsplittable_tokens_pattern
-        )
+        raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern)
         token_row_splits = raw_tokens.row_splits
         flat_tokens = raw_tokens.flat_values
@@ -646,7 +658,7 @@ def get_config(self):
             {
                 "sequence_length": self.sequence_length,
                 "add_prefix_space": self.add_prefix_space,
-                "unsplittable_tokens": self.unsplittable_tokens,
+                "special_tokens": self.special_tokens,
             }
         )
-        return config
+        return config
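As a usage sketch of the renamed arguments (assuming a `keras_nlp` build that includes this change; the toy vocabulary, merges, and `<|endoftext|>` token are made up for illustration):

```python
import keras_nlp

# Toy vocabulary and merge rules; "<|endoftext|>" stands in for a special
# token that contains otherwise-splittable punctuation.
vocab = {"<|endoftext|>": 0, "butter": 1, "fly": 2}
merges = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]

tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
    vocabulary=vocab,
    merges=merges,
    special_tokens=["<|endoftext|>"],
    special_tokens_in_strings=True,  # keep special tokens intact in raw inputs
)

print(tokenizer("butterfly<|endoftext|>"))
# Expected ids with this toy setup: [1, 2, 0]
```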