Commit 74c3557

Rename unsplittable to special
1 parent c231372 commit 74c3557

File tree

11 files changed: +68 -71 lines

keras_nlp/models/bart/bart_tokenizer.py

Lines changed: 3 additions & 3 deletions
@@ -90,12 +90,12 @@ def __init__(
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=[
+            special_tokens=[
                 self.start_token,
                 self.pad_token,
                 self.end_token,
             ],
-            unsplittable_tokens_in_strings=special_tokens_in_strings,
+            special_tokens_in_strings=special_tokens_in_strings,
             **kwargs,
         )

@@ -113,5 +113,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

     def get_config(self):
         config = super().get_config()
-        del config["unsplittable_tokens"]  # Not configurable; set in __init__.
+        del config["special_tokens"]  # Not configurable; set in __init__.
         return config
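
For downstream users the change is only in the keyword names on the base class; the model tokenizers above keep forwarding the same values. A minimal sketch of constructing the base BytePairTokenizer directly with the renamed arguments (the toy vocabulary and single merge rule below are hypothetical, not taken from the library):

import keras_nlp

# Hypothetical toy vocabulary/merges; real tokenizers load these from presets.
vocab = {"<s>": 0, "<pad>": 1, "</s>": 2, "a": 3, "b": 4, "ab": 5}
merges = ["a b"]

tokenizer = keras_nlp.tokenizers.BytePairTokenizer(
    vocabulary=vocab,
    merges=merges,
    special_tokens=["<s>", "<pad>", "</s>"],  # was `unsplittable_tokens`
    special_tokens_in_strings=True,  # was `unsplittable_tokens_in_strings`
)
# With the flag on, "</s>" in the input should map to its single id (2)
# instead of being split into "</", "s", ">" by the word-level pre-split.
print(tokenizer("ab</s>"))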

keras_nlp/models/bloom/bloom_tokenizer.py

Lines changed: 3 additions & 3 deletions
@@ -82,12 +82,12 @@ def __init__(
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=[
+            special_tokens=[
                 self.start_token,
                 self.end_token,
                 self.pad_token,
             ],
-            unsplittable_tokens_in_strings=special_tokens_in_strings,
+            special_tokens_in_strings=special_tokens_in_strings,
             **kwargs,
         )

@@ -105,5 +105,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

     def get_config(self):
         config = super().get_config()
-        del config["unsplittable_tokens"]  # Not configurable; set in __init__.
+        del config["special_tokens"]  # Not configurable; set in __init__.
         return config

keras_nlp/models/falcon/falcon_tokenizer.py

Lines changed: 3 additions & 3 deletions
@@ -81,8 +81,8 @@ def __init__(
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=[self.end_token],
-            unsplittable_tokens_in_strings=special_tokens_in_strings,
+            special_tokens=[self.end_token],
+            special_tokens_in_strings=special_tokens_in_strings,
             **kwargs,
         )

@@ -100,5 +100,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

     def get_config(self):
         config = super().get_config()
-        del config["unsplittable_tokens"]  # Not configurable; set in __init__.
+        del config["special_tokens"]  # Not configurable; set in __init__.
         return config

keras_nlp/models/falcon/falcon_tokenizer_test.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def setUp(self):
             "vocabulary": self.vocab,
             "merges": self.merges,
             "special_tokens_in_strings": True,
-        }
+        }
         self.input_data = [
             " airplane at airport<|endoftext|>",
             " airplane airport",

keras_nlp/models/gpt2/gpt2_tokenizer.py

Lines changed: 3 additions & 3 deletions
@@ -81,8 +81,8 @@ def __init__(
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=[self.end_token],
-            unsplittable_tokens_in_strings=special_tokens_in_strings,
+            special_tokens=[self.end_token],
+            special_tokens_in_strings=special_tokens_in_strings,
             **kwargs,
         )

@@ -100,5 +100,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

     def get_config(self):
         config = super().get_config()
-        del config["unsplittable_tokens"]  # Not configurable; set in __init__.
+        del config["special_tokens"]  # Not configurable; set in __init__.
         return config
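
From the user side the public flag on GPT2Tokenizer keeps its name; only the base-class keyword it forwards to is renamed. A hedged usage sketch ("gpt2_base_en" is a real KerasNLP preset; the call downloads vocabulary assets on first use):

import keras_nlp

tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset(
    "gpt2_base_en",
    special_tokens_in_strings=True,  # public name unchanged by this commit
)
ids = tokenizer("a test<|endoftext|>")
# "<|endoftext|>" should come back as the single end-of-text id rather than
# several sub-tokens, because it is registered via `special_tokens`
# (formerly `unsplittable_tokens`) on the base BytePairTokenizer.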

keras_nlp/models/gpt_neo_x/gpt_neo_x_tokenizer.py

Lines changed: 3 additions & 3 deletions
@@ -59,8 +59,8 @@ def __init__(
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=[self.end_token],
-            unsplittable_tokens_in_strings=special_tokens_in_strings,
+            special_tokens=[self.end_token],
+            special_tokens_in_strings=special_tokens_in_strings,
             **kwargs,
         )

@@ -78,5 +78,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

     def get_config(self):
         config = super().get_config()
-        del config["unsplittable_tokens"]  # Not configurable; set in __init__.
+        del config["special_tokens"]  # Not configurable; set in __init__.
         return config

keras_nlp/models/opt/opt_tokenizer.py

Lines changed: 3 additions & 3 deletions
@@ -82,12 +82,12 @@ def __init__(
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=[
+            special_tokens=[
                 self.start_token,
                 self.pad_token,
                 self.end_token,
             ],
-            unsplittable_tokens_in_strings=special_tokens_in_strings,
+            special_tokens_in_strings=special_tokens_in_strings,
             **kwargs,
         )

@@ -105,5 +105,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

     def get_config(self):
         config = super().get_config()
-        del config["unsplittable_tokens"]  # Not configurable; set in __init__.
+        del config["special_tokens"]  # Not configurable; set in __init__.
         return config

keras_nlp/models/roberta/roberta_tokenizer.py

Lines changed: 3 additions & 3 deletions
@@ -90,13 +90,13 @@ def __init__(
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=[
+            special_tokens=[
                 self.start_token,
                 self.pad_token,
                 self.end_token,
                 self.mask_token,
             ],
-            unsplittable_tokens_in_strings=special_tokens_in_strings,
+            special_tokens_in_strings=special_tokens_in_strings,
             **kwargs,
         )

@@ -116,5 +116,5 @@ def set_vocabulary_and_merges(self, vocabulary, merges):

     def get_config(self):
         config = super().get_config()
-        del config["unsplittable_tokens"]  # Not configurable; set in __init__.
+        del config["special_tokens"]  # Not configurable; set in __init__.
         return config

keras_nlp/models/whisper/whisper_tokenizer.py

Lines changed: 7 additions & 6 deletions
@@ -99,7 +99,8 @@ def __init__(
         self.translate_token_id = special_tokens[self.translate_token]
         self.transcribe_token_id = special_tokens[self.transcribe_token]

-        self.special_tokens = special_tokens
+        # Underscore to distinguish it from `self.special_tokens` in base class.
+        self._special_tokens = special_tokens
         self.language_tokens = language_tokens

         # TODO: Add language tokens to `unsplittable_tokens` once we figure
@@ -109,8 +110,8 @@ def __init__(
         super().__init__(
             vocabulary=vocabulary,
             merges=merges,
-            unsplittable_tokens=unsplittable_tokens,
-            unsplittable_tokens_in_strings=special_tokens_in_strings,
+            special_tokens=unsplittable_tokens,
+            special_tokens_in_strings=special_tokens_in_strings,
             **kwargs,
         )

@@ -146,18 +147,18 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
                 self.translate_token,
                 self.transcribe_token,
             ]:
-                vocabulary[token] = self.special_tokens[token]
+                vocabulary[token] = self._special_tokens[token]
         else:
             self._initial_vocabulary = None

         super().set_vocabulary_and_merges(vocabulary, merges)

     def get_config(self):
         config = super().get_config()
-        del config["unsplittable_tokens"]  # Not configurable; set in __init__.
+        del config["special_tokens"]  # Not configurable; set in __init__.
         config.update(
             {
-                "special_tokens": self.special_tokens,
+                "special_tokens": self._special_tokens,
                 "language_tokens": self.language_tokens,
             }
         )

keras_nlp/tokenizers/byte_pair_tokenizer.py

Lines changed: 32 additions & 36 deletions
@@ -59,10 +59,10 @@
 SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""


-def get_unsplittable_tokens_pattern(unsplittable_tokens):
-    if unsplittable_tokens is None or len(unsplittable_tokens) == 0:
+def get_special_tokens_pattern(special_tokens):
+    if special_tokens is None or len(special_tokens) == 0:
         return None
-    return r"|".join([re.escape(token) for token in unsplittable_tokens])
+    return r"|".join([re.escape(token) for token in special_tokens])


 def bytes_to_unicode():
@@ -97,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result


-def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
+def split_strings_for_bpe(inputs, special_tokens_pattern=None):
     # We need to recreate the exact behavior of token presplitting in the
     # original gpt2 tokenizer which uses a lookahead. As re2 does not
     # support lookahead match, we are using an alternative insert a special
@@ -110,23 +110,23 @@ def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
         inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
     )

-    if unsplittable_tokens_pattern is not None:
-        # First split the unsplittable tokens from the input.
+    if special_tokens_pattern is not None:
+        # First split the special tokens from the input.
         raw_tokens = tf_text.regex_split(
-            inputs, unsplittable_tokens_pattern, unsplittable_tokens_pattern
+            inputs, special_tokens_pattern, special_tokens_pattern
         )
-        # Then split using both `unsplittable_tokens_pattern` and
+        # Then split using both `special_tokens_pattern` and
         # `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
-        # affecting the unsplittable tokens.
-        # We split unsplittable tokens first then apply this split instead of
+        # affecting the special tokens.
+        # We split special tokens first then apply this split instead of
         # applying this split directly, because otherwise we will not split
-        # unsplittable tokens from inputs properly, because of this pattern
+        # special tokens from inputs properly, because of this pattern
         # ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
         # e.g., [" </s>"] will be [" </", "s", ">"] instead of [" ", "</s>"]
         raw_tokens = tf_text.regex_split(
             raw_tokens,
-            r"|".join([unsplittable_tokens_pattern, SPLIT_PATTERN_1]),
-            r"|".join([unsplittable_tokens_pattern, SPLIT_PATTERN_1]),
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
         )
         raw_tokens = raw_tokens.merge_dims(-2, -1)
     else:
@@ -238,16 +238,16 @@ class BytePairTokenizer(tokenizer.Tokenizer):
             a prefix space to the first word will cause it to be tokenized
             equivalently to all subsequent words in the sequence.
             Defaults to `False`.
-        unsplittable_tokens: list. A list of unsplittable tokens. when
-            `unsplittable_tokens_in_strings` is set to `True`, unsplittable
+        special_tokens: list. A list of special tokens. when
+            `special_tokens_in_strings` is set to `True`, special
            tokens will never be split during the word-level splitting applied
            before the byte-pair encoding. This can be used to ensure special
            tokens map to unique indices in the vocabulary, even if these
-            unsplittable tokens contain splittable characters such as
-            punctuation. Unsplittable tokens must still be included in
+            special tokens contain splittable characters such as
+            punctuation. special tokens must still be included in
            `vocabulary`. Defaults to `None`.
-        unsplittable_tokens_in_strings: bool. To indicate if the tokenizer
-            should expect unsplittable tokens in input strings that should be
+        special_tokens_in_strings: bool. To indicate if the tokenizer
+            should expect special tokens in input strings that should be
            tokenized and mapped correctly to their ids. Defaults to False.

     Examples:
@@ -287,8 +287,8 @@ def __init__(
         merges=None,
         sequence_length=None,
         add_prefix_space=False,
-        unsplittable_tokens=None,
-        unsplittable_tokens_in_strings=False,
+        special_tokens=None,
+        special_tokens_in_strings=False,
         dtype="int32",
         **kwargs,
     ) -> None:
@@ -303,11 +303,11 @@ def __init__(
         super().__init__(dtype=dtype, **kwargs)
         self.sequence_length = sequence_length
         self.add_prefix_space = add_prefix_space
-        self.unsplittable_tokens = unsplittable_tokens
-        self._unsplittable_tokens_pattern = None
-        if unsplittable_tokens_in_strings:
-            self._unsplittable_tokens_pattern = get_unsplittable_tokens_pattern(
-                unsplittable_tokens
+        self.special_tokens = special_tokens
+        self._special_tokens_pattern = None
+        if special_tokens_in_strings:
+            self._special_tokens_pattern = get_special_tokens_pattern(
+                special_tokens
             )

         # Create byte <=> unicode mapping. This is useful for handling
@@ -362,8 +362,8 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
             )

         # Check for special tokens in vocabulary.
-        if self.unsplittable_tokens is not None:
-            for token in self.unsplittable_tokens:
+        if self.special_tokens is not None:
+            for token in self.special_tokens:
                 if token not in self.get_vocabulary():
                     raise ValueError(
                         f"Cannot find token `'{token}'` in the provided "
@@ -383,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
             )

         self.cache = BytePairTokenizerCache()
-        if self.unsplittable_tokens:
+        if self.special_tokens and self._special_tokens_pattern is not None:
             # Put special tokens into cache, so it won't be further split and
             # merged.
-            self.cache.insert(
-                self.unsplittable_tokens, self.unsplittable_tokens
-            )
+            self.cache.insert(self.special_tokens, self.special_tokens)

         # Create mapping between string tokens to int ids, and vice versa.
         byte_pairs = [x[0] for x in self.vocabulary.items()]
@@ -566,9 +564,7 @@ def tokenize(self, inputs):
         if scalar_input:
             inputs = tf.expand_dims(inputs, 0)

-        raw_tokens = split_strings_for_bpe(
-            inputs, self._unsplittable_tokens_pattern
-        )
+        raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern)
         token_row_splits = raw_tokens.row_splits
         flat_tokens = raw_tokens.flat_values

@@ -662,7 +658,7 @@ def get_config(self):
             {
                 "sequence_length": self.sequence_length,
                 "add_prefix_space": self.add_prefix_space,
-                "unsplittable_tokens": self.unsplittable_tokens,
+                "special_tokens": self.special_tokens,
             }
         )
         return config
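
The two regex_split passes in split_strings_for_bpe are what let special tokens survive the word-level split. A condensed, self-contained sketch of that logic (SIMPLE_SPLIT below is a hypothetical stand-in for the library's longer SPLIT_PATTERN_1, and the "६" whitespace-marker trick is omitted; tensorflow_text assumed installed):

import re
import tensorflow_text as tf_text

def get_special_tokens_pattern(special_tokens):
    # Same helper as in the diff: escape each token and OR them together.
    if not special_tokens:
        return None
    return r"|".join([re.escape(token) for token in special_tokens])

# Hypothetical stand-in for SPLIT_PATTERN_1 (the real pattern also handles
# contractions and special whitespace characters).
SIMPLE_SPLIT = r" ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+"

pattern = get_special_tokens_pattern(["</s>"])
inputs = [" airplane at airport</s>"]

# Pass 1: carve the special tokens out as standalone pieces.
raw = tf_text.regex_split(inputs, pattern, pattern)
# Pass 2: word-split everything, OR-ing `pattern` back in so the special
# pieces match as a whole and are not re-split into " </", "s", ">".
combined = r"|".join([pattern, SIMPLE_SPLIT])
raw = tf_text.regex_split(raw, combined, combined)
print(raw.merge_dims(-2, -1))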
