[WIP] add character-level tokenizer #1515

Open
wants to merge 3 commits into master
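For reviewers, a minimal usage sketch of the API this patch adds. The sample text and `vocab_size` are illustrative, and the snippet assumes the `CharTokenizer` from the diff below is importable:

```python
from gluonnlp.data.tokenizers import CharTokenizer

tokenizer = CharTokenizer(lowercase=True)
# Build a character vocabulary from raw text; the most frequent characters
# are kept and the unk token is prepended to the vocabulary.
tokenizer.set_vocab_from_text(["hello world"], vocab_size=10)
print(tokenizer.encode("hello world"))       # ['h', 'e', 'l', 'l', 'o', ' ', 'w', ...]
print(tokenizer.encode("hello world", int))  # the same characters as vocabulary ids
```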
132 changes: 123 additions & 9 deletions src/gluonnlp/data/tokenizers.py
@@ -17,22 +17,23 @@
 """Tokenizers."""
 __all__ = ['WhitespaceTokenizer', 'SpacyTokenizer', 'JiebaTokenizer', 'MosesTokenizer',
            'SubwordNMTTokenizer', 'YTTMTokenizer', 'SentencepieceTokenizer',
-           'HuggingFaceBPETokenizer', 'HuggingFaceByteBPETokenizer',
-           'HuggingFaceWordPieceTokenizer',
-           'create', 'create_with_json', 'list_all']
+           'HuggingFaceBPETokenizer', 'HuggingFaceByteBPETokenizer', 'HuggingFaceWordPieceTokenizer',
+           'CharTokenizer', 'create', 'create_with_json', 'list_all']
 
-from typing import List, Tuple, Union, Optional
 import os
-import json
-from collections import OrderedDict
 import abc
 import sys
+import json
 import warnings
 import itertools
-from typing import NewType
-import sacremoses
-import jieba
+import unicodedata
+from uuid import uuid4
+from typing import List, Tuple, Union, NewType, Optional
+from collections import OrderedDict
+
+import jieba
+import sacremoses
 
 from .vocab import Vocab
 from ..registry import TOKENIZER_REGISTRY
 from ..utils.lazy_imports import try_import_subword_nmt, \
@@ -1731,3 +1732,116 @@ def create_with_json(name: str, json_str: str) -> BaseTokenizer:

def list_all():
return TOKENIZER_REGISTRY.list_keys()


@TOKENIZER_REGISTRY.register('char')
class CharTokenizer(BaseTokenizerWithVocab):
    """Tokenizer that splits text into a sequence of individual characters.

    Characters that are not found in the vocabulary are mapped to `unk_token`.
    The text can optionally be lowercased and unicode-normalized ('NFC',
    'NFKC', 'NFD', or 'NFKD') before being split.
    """
    def __init__(self, vocab: Optional[Vocab] = None,
                 unk_token: Optional[str] = Vocab.UNK_TOKEN,
                 lowercase: bool = False,
                 unicode_normalizer: Optional[str] = None):
        if unicode_normalizer is not None:
            assert unicode_normalizer in ['NFC', 'NFKC', 'NFD', 'NFKD'], \
                'Unsupported unicode normalizer: {}'.format(unicode_normalizer)
        self._vocab = vocab
        self._unk_token = unk_token
        self._lowercase = lowercase
        self._unicode_normalizer = unicode_normalizer

def process_text(self, text):
if self._lowercase:
text = text.lower()
if self._unicode_normalizer:
text = unicodedata.normalize(self._unicode_normalizer, text)
return text

    def text_to_list(self, text):
        """Split the processed text into characters, mapping characters that
        are not in the vocabulary to the unk token."""
        sequence = list(self.process_text(text))
        if self._vocab is None:
            return sequence
        # Use a set for O(1) membership tests instead of scanning all_tokens
        # once per character.
        known_tokens = set(self._vocab.all_tokens)
        return [token if token in known_tokens else self._unk_token
                for token in sequence]

def encode_with_offsets(self, sentences, output_type=str):
raise NotImplementedError('We cannot obtain the original offsets for CharTokenizer.')

def encode(self, sentences, output_type=str):
is_multiple_sentences = isinstance(sentences, list)
if not is_multiple_sentences:
sentences = [sentences]

if output_type is str:
tokens = [self.text_to_list(sentence) for sentence in sentences]
elif output_type is int:
if self._vocab is None:
raise ValueError(_encode_no_vocab_err_msg())
tokens = [self._vocab[self.text_to_list(sentence)] for sentence in sentences]
else:
raise NotImplementedError
if is_multiple_sentences:
return tokens
else:
return tokens[0]

    def decode(self, tokens):
        """Concatenate the characters back into text. `tokens` is a list of
        str/int for a single sentence, or a list of such lists."""
        if len(tokens) == 0:
            return ''
        if isinstance(tokens[0], list):
            return [self.decode(ele_tokens) for ele_tokens in tokens]
        if isinstance(tokens[0], int):
            if self._vocab is None:
                raise ValueError(_decode_no_vocab_err_msg())
            tokens = self._vocab.to_tokens(tokens)
        return ''.join(tokens)

@property
def vocab(self):
return self._vocab

    def set_vocab(self, vocab):
        """Set the vocabulary of the tokenizer.

        Parameters
        ----------
        vocab
            The new vocabulary.
        """
        self._vocab = vocab

    def set_vocab_from_text(self, sentences, vocab_size=None):
        """Build the vocabulary of the tokenizer from raw text.

        Parameters
        ----------
        sentences
            A sentence or a list of sentences whose characters are counted.
        vocab_size
            If given, keep only the `vocab_size` most frequent characters;
            the unk token is added on top of these.
        """
is_multiple_sentences = isinstance(sentences, list)
if not is_multiple_sentences:
sentences = [sentences]
word_counts = OrderedDict()
for sentence in sentences:
sentence = self.process_text(sentence)
for word in sentence:
if word in word_counts:
word_counts[word] += 1
else:
word_counts[word] = 1
word_lists = list(word_counts.items())
word_lists.sort(key=lambda x: x[1], reverse=True)
if vocab_size:
word_lists = word_lists[:vocab_size]
if self._unk_token is None:
all_tokens = []
else:
all_tokens = [self._unk_token]
all_tokens.extend(wc[0] for wc in word_lists)
self._vocab = Vocab(all_tokens, unk_token=self._unk_token)

    def set_lowercase(self, lowercase: bool):
self._lowercase = lowercase
# no need to rebuild the tokenizer

@property
def lowercase(self):
return self._lowercase

    def __repr__(self):
        ret = '{}(\n' \
              '   lowercase = {}, vocab = {}\n' \
              '   unicode_normalizer = {}\n' \
              ')'.format(self.__class__.__name__,
                         self._lowercase, self._vocab,
                         self._unicode_normalizer)
        return ret
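Because the class is registered under the name 'char', it should also be reachable through the module's registry helpers. A sketch, assuming `create` forwards keyword arguments to the constructor the way it does for the other registered tokenizers:

```python
from gluonnlp.data.tokenizers import create, create_with_json

# Two equivalent ways to construct the tokenizer via the registry.
char_tok = create('char', lowercase=True)
char_tok_json = create_with_json('char', '{"lowercase": true}')
```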
17 changes: 16 additions & 1 deletion tests/test_data_tokenizers.py
@@ -8,7 +8,7 @@
 import tempfile
 from gluonnlp.data.tokenizers import WhitespaceTokenizer, MosesTokenizer, JiebaTokenizer,\
     SpacyTokenizer, SubwordNMTTokenizer, YTTMTokenizer, SentencepieceTokenizer, \
-    HuggingFaceBPETokenizer, HuggingFaceByteBPETokenizer, HuggingFaceWordPieceTokenizer
+    HuggingFaceBPETokenizer, HuggingFaceByteBPETokenizer, HuggingFaceWordPieceTokenizer, CharTokenizer
 from gluonnlp.base import get_repo_url
 from gluonnlp.data import Vocab
 from gluonnlp.utils.misc import download
@@ -619,3 +619,18 @@ def test_huggingface_wordpiece_tokenizer():

os.remove(vocab_path)
os.remove(hf_vocab_path)


def test_charlevel_tokenizer():
    tokenizer = CharTokenizer(lowercase=False)
    sentences = ["hello , y ' all ! how are you ?"]
    tokenizer.set_vocab_from_text(sentences, vocab_size=10)
    assert tokenizer.vocab.all_tokens == \
        ['<unk>', ' ', 'l', 'o', 'h', 'e', 'y', 'a', ',', "'", '!']
    gt_tokenized = [['<unk>', 'e', 'l', 'l', 'o', ',', ' ', 'y', "'", 'a', 'l', 'l', '!', ' ',
                     '<unk>', 'o', '<unk>', ' ', 'a', '<unk>', 'e', ' ', 'y', 'o', '<unk>', ' ',
                     '<unk>', ' ', '<unk>', ' ', '<unk>', ' ', '<unk>', ' ', '<unk>']]
    gt_decode = ["hello, y'all! how are you?"]
    verify_encode_token(tokenizer, SUBWORD_TEST_SAMPLES[0:1], gt_tokenized)
    verify_pickleble(tokenizer, CharTokenizer)
    # TODO(zheyuye), test decode
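With `decode` implemented as sketched above, the TODO could be closed with a round-trip check on in-vocabulary characters (a hypothetical assertion; characters that were mapped to the unk token are not recoverable, so `gt_decode` cannot be matched exactly):

```python
    # Round trip only holds for characters that are in the vocabulary.
    assert tokenizer.decode(['h', 'e', 'l', 'l', 'o']) == 'hello'
    assert tokenizer.decode(tokenizer.encode('hello', int)) == 'hello'
```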