Commit d0ff826

Add special_tokens_in_strings to byte_pair_tokenizer
1 parent 29873a9 commit d0ff826

2 files changed: +62 -40 lines

keras_nlp/tokenizers/byte_pair_tokenizer.py

Lines changed: 48 additions & 36 deletions
@@ -59,10 +59,10 @@
 SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
 
 
-def get_unsplittable_tokens_pattern(unsplittable_tokens):
-    if unsplittable_tokens is None or len(unsplittable_tokens) == 0:
+def get_special_tokens_pattern(special_tokens):
+    if special_tokens is None or len(special_tokens) == 0:
         return None
-    return r"|".join([re.escape(token) for token in unsplittable_tokens])
+    return r"|".join([re.escape(token) for token in special_tokens])
 
 
 def bytes_to_unicode():
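
For context, the renamed helper just joins the escaped special tokens into a single RE2 alternation, which is later passed to tf_text.regex_split. A minimal sketch of the expected result (not part of the commit; the tokens are illustrative):

    import re

    # Same expression as get_special_tokens_pattern above.
    pattern = r"|".join([re.escape(token) for token in ["<s>", "</s>"]])
    print(pattern)  # <s>|</s>
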
@@ -97,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
+def split_strings_for_bpe(inputs, special_tokens_pattern=None):
     # We need to recreate the exact behavior of token presplitting in the
     # original gpt2 tokenizer which uses a lookahead. As re2 does not
     # support lookahead match, we are using an alternative insert a special
@@ -110,26 +110,23 @@ def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
         inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
     )
 
-    if unsplittable_tokens_pattern is not None:
-        # First split the unsplittable tokens from the input.
+    if special_tokens_pattern is not None:
+        # First split the special tokens from the input.
         raw_tokens = tf_text.regex_split(
-            inputs, unsplittable_tokens_pattern, unsplittable_tokens_pattern
+            inputs, special_tokens_pattern, special_tokens_pattern
         )
-        split_pattern_1_with_unsplittable_tokens = r"|".join(
-            [unsplittable_tokens_pattern, SPLIT_PATTERN_1]
-        )
-        # Then split using both `unsplittable_tokens_pattern` and
+        # Then split using both `special_tokens_pattern` and
         # `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
-        # affecting the unsplittable tokens.
-        # We split unsplittable tokens first then apply this split instead of
+        # affecting the special tokens.
+        # We split special tokens first then apply this split instead of
         # applying this split directly, because otherwise we will not split
-        # unsplittable tokens from inputs properly, because of this pattern
+        # special tokens from inputs properly, because of this pattern
         # ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
         # e.g., [" </s>"] will be [" </", "s", ">"] instead of [" ", "</s>"]
         raw_tokens = tf_text.regex_split(
             raw_tokens,
-            split_pattern_1_with_unsplittable_tokens,
-            split_pattern_1_with_unsplittable_tokens,
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
         )
         raw_tokens = raw_tokens.merge_dims(-2, -1)
     else:
@@ -241,12 +238,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
             a prefix space to the first word will cause it to be tokenized
             equivalently to all subsequent words in the sequence.
             Defaults to `False`.
-        unsplittable_tokens: list. A list of strings that will
-            never be split during the word-level splitting applied before the
-            byte-pair encoding. This can be used to ensure special tokens map to
-            unique indices in the vocabulary, even if these special tokens
-            contain splittable characters such as punctuation. Special tokens
-            must still be included in `vocabulary`. Defaults to `None`.
+        special_tokens: list. A list of special tokens. When
+            `special_tokens_in_strings` is set to `True`, special
+            tokens will never be split during the word-level splitting applied
+            before the byte-pair encoding. This can be used to ensure special
+            tokens map to unique indices in the vocabulary, even if these
+            special tokens contain splittable characters such as
+            punctuation. Special tokens must still be included in
+            `vocabulary`. Defaults to `None`.
+        special_tokens_in_strings: bool. To indicate if the tokenizer
+            should expect special tokens in input strings that should be
+            tokenized and mapped correctly to their ids. Defaults to `False`.
 
     Examples:
 
@@ -285,7 +287,8 @@ def __init__(
         merges=None,
         sequence_length=None,
         add_prefix_space=False,
-        unsplittable_tokens=None,
+        special_tokens=None,
+        special_tokens_in_strings=False,
         dtype="int32",
         **kwargs,
     ) -> None:
@@ -300,10 +303,12 @@
         super().__init__(dtype=dtype, **kwargs)
         self.sequence_length = sequence_length
         self.add_prefix_space = add_prefix_space
-        self.unsplittable_tokens = unsplittable_tokens
-        self._unsplittable_tokens_pattern = get_unsplittable_tokens_pattern(
-            unsplittable_tokens
-        )
+        self.special_tokens = special_tokens
+        self._special_tokens_pattern = None
+        if special_tokens_in_strings:
+            self._special_tokens_pattern = get_special_tokens_pattern(
+                special_tokens
+            )
 
         # Create byte <=> unicode mapping. This is useful for handling
         # whitespace tokens.
@@ -355,6 +360,17 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
                 "token to int ids. Received: "
                 f"`type(vocabulary)={type(vocabulary)}`."
             )
+
+        # Check for special tokens in vocabulary.
+        if self.special_tokens is not None:
+            for token in self.special_tokens:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )
+
         if isinstance(merges, str):
             with open(merges, encoding="utf-8") as f:
                 self.merges = [bp.rstrip() for bp in f]
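
The check added above runs inside set_vocabulary_and_merges, so a special token missing from the vocabulary now fails at construction time. A small sketch of that behavior (a dict vocabulary is used so the new special-token check, not the vocabulary-type check, is what raises; the import path is assumed):

    from keras_nlp.tokenizers import BytePairTokenizer

    try:
        BytePairTokenizer(
            vocabulary={"a": 0, "b": 1, "c": 2},
            merges=[],
            special_tokens=["d"],  # "d" is not in the vocabulary
        )
    except ValueError as err:
        print(err)  # Cannot find token `'d'` in the provided `vocabulary` ...
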
@@ -367,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
             )
 
         self.cache = BytePairTokenizerCache()
-        if self.unsplittable_tokens:
+        if self.special_tokens and self._special_tokens_pattern is not None:
             # Put special tokens into cache, so it won't be further split and
             # merged.
-            self.cache.insert(
-                self.unsplittable_tokens, self.unsplittable_tokens
-            )
+            self.cache.insert(self.special_tokens, self.special_tokens)
 
         # Create mapping between string tokens to int ids, and vice versa.
         byte_pairs = [x[0] for x in self.vocabulary.items()]
@@ -550,9 +564,7 @@ def tokenize(self, inputs):
         if scalar_input:
             inputs = tf.expand_dims(inputs, 0)
 
-        raw_tokens = split_strings_for_bpe(
-            inputs, self._unsplittable_tokens_pattern
-        )
+        raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern)
         token_row_splits = raw_tokens.row_splits
         flat_tokens = raw_tokens.flat_values
 
@@ -646,7 +658,7 @@ def get_config(self):
             {
                 "sequence_length": self.sequence_length,
                 "add_prefix_space": self.add_prefix_space,
-                "unsplittable_tokens": self.unsplittable_tokens,
+                "special_tokens": self.special_tokens,
             }
         )
-        return config
+        return config
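
Taken together, the new flag behaves roughly as exercised by the updated test below; a usage sketch (vocab and merges copied from that test, import path assumed):

    from keras_nlp.tokenizers import BytePairTokenizer

    vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
    merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]

    # With special_tokens_in_strings=True, "<s>" and "</s>" inside the raw
    # string are kept intact and mapped to their ids instead of being split.
    tokenizer = BytePairTokenizer(
        vocabulary=vocab,
        merges=merges,
        special_tokens=["<s>", "</s>"],
        special_tokens_in_strings=True,
    )
    print(tokenizer("<s>a quick fox</s>"))  # [0, 2, 3, 4, 1]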

keras_nlp/tokenizers/byte_pair_tokenizer_test.py

Lines changed: 14 additions & 4 deletions
@@ -67,30 +67,40 @@ def test_tokenize_with_special_tokens(self):
         tokenizer = BytePairTokenizer(
             vocabulary=vocab,
             merges=merges,
-            unsplittable_tokens=["s", "p"],
+            special_tokens=["s", "p"],
+            special_tokens_in_strings=True,
         )
         output = tokenizer("sp")
         self.assertAllEqual(output, [1, 2])
 
-        # If not setting special tokens, "sp" is one token.
+        # If `special_tokens_in_strings` is not `True`, "sp" is one token.
         tokenizer = BytePairTokenizer(
             vocabulary=vocab,
             merges=merges,
+            special_tokens=["s", "p"],
         )
         output = tokenizer("sp")
         self.assertAllEqual(output, [0])
 
+        # Test real world special tokens, e.g., <s> and </s>.
         vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
         merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
         merges += ["Ġ f", "o x", "Ġf ox"]
         tokenizer = BytePairTokenizer(
             vocabulary=vocab,
             merges=merges,
-            unsplittable_tokens=["<s>", "</s>"],
+            special_tokens=["<s>", "</s>"],
+            special_tokens_in_strings=True,
         )
         output = tokenizer("<s>a quick fox</s>")
         self.assertAllEqual(output, [0, 2, 3, 4, 1])
 
+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            BytePairTokenizer(
+                vocabulary=["a", "b", "c"], merges=[], special_tokens=["d"]
+            )
+
     def test_tokenize_prefix_space(self):
         input_data = ["brown.", "black."]
         tokenizer = BytePairTokenizer(
@@ -181,4 +191,4 @@ def test_config(self):
         self.assertAllEqual(
             self.tokenizer(input_data),
             cloned_tokenizer(input_data),
-        )
+        )
