
Commit 1f9ad44

Horovod support for pretraining and fine-tuning SQuAD (#1276)
* fix roberta
* fix xlmr
* fix token_ids
* fix
* use_segmentation
* fix roberta
* update
* fix
* fix mobilebert
* repeat
* repeat for pretraining
* revise
* revise train_transformer
* upload gluon_electra_small_owt
* fix openwebtext
* fix wiki
* fix bookcorpus
* multiprocessing for wiki
* update
* rename
* index_update
* topk
* revise
* layer-wise decay
* fix mobilebert
* try
* update hyper-parameters of adamw
* fix roberta
* clip_grad_global_norm with zeros max_grad_norm
* fix ModelForQABasic
* multiply_grads
* remove multiply_grads
* fix
* horovod for squad
* update
* inference without horovod
* fix
* update
* re-upload roberta
* fix get_pretrained
* re-upload xlmr
* update testings
* tiny update on run_squad
* test
* lowercase
* CharTokenizer
* Squashed commit of the following:

  commit 35a586676036f627bffd0d3c753c6cd0a70d63cf (ZheyuYe, Fri Jul 17 10:10:14 2020 +0800)
  Squashed commit of the following:

    commit 673344d (ZheyuYe, Wed Jul 15 22:43:07 2020 +0800) CharTokenizer
    commit 8dabfd6 (ZheyuYe, Wed Jul 15 15:47:24 2020 +0800) lowercase
    commit f5c94a6 (ZheyuYe, Tue Jul 14 17:45:28 2020 +0800) test
    commit dc55fc9 (ZheyuYe, Tue Jul 14 05:45:01 2020 +0800) tiny update on run_squad
    commit 4defc7a (ZheyuYe, Mon Jul 13 23:18:08 2020 +0800) update testings
    commit 2719e81 (ZheyuYe, Mon Jul 13 23:08:32 2020 +0800) re-upload xlmr
    commit cd0509d (ZheyuYe, Mon Jul 13 22:30:47 2020 +0800) fix get_pretrained
    commit 8ed8a72 (ZheyuYe, Mon Jul 13 22:28:13 2020 +0800) re-upload roberta
    commit 5811d40 (ZheyuYe, Mon Jul 13 18:27:23 2020 +0800) update
    commit 44a09a3 (ZheyuYe, Sat Jul 11 15:06:33 2020 +0800) fix
    commit 4074a26 (ZheyuYe, Fri Jul 10 16:08:49 2020 +0800) inference without horovod
    commit 31cb953 (ZheyuYe, Thu Jul 9 18:41:55 2020 +0800) update
    commit 838be2a (ZheyuYe, Thu Jul 9 15:14:39 2020 +0800) horovod for squad
    commit 1d374a2 (ZheyuYe, Thu Jul 9 12:09:19 2020 +0800) fix
    commit e4fba39 (ZheyuYe, Thu Jul 9 10:35:08 2020 +0800) remove multiply_grads
    commit 007f07e (ZheyuYe, Tue Jul 7 11:26:38 2020 +0800) multiply_grads
    commit b8c85bb (ZheyuYe, Mon Jul 6 12:28:56 2020 +0800) fix ModelForQABasic
    commit 0e13a58 (ZheyuYe, Sat Jul 4 18:42:12 2020 +0800) clip_grad_global_norm with zeros max_grad_norm
    commit bd270f2 (ZheyuYe, Fri Jul 3 20:21:31 2020 +0800) fix roberta
    commit 4fc564c (ZheyuYe, Fri Jul 3 19:36:08 2020 +0800) update hyper-parameters of adamw
    commit 59cffbf (ZheyuYe, Fri Jul 3 16:25:46 2020 +0800) try
    commit a84f782 (ZheyuYe, Thu Jul 2 20:39:03 2020 +0800) fix mobilebert
    commit 4bc3a96 (ZheyuYe, Thu Jul 2 11:14:39 2020 +0800) layer-wise decay
    commit 07186d5 (ZheyuYe, Thu Jul 2 02:14:43 2020 +0800) revise
    commit a5a6475 (ZheyuYe, Wed Jul 1 19:50:20 2020 +0800) topk
    commit 34ee884 (ZheyuYe, Wed Jul 1 19:25:09 2020 +0800) index_update
    commit 74178e2 (ZheyuYe, Wed Jul 1 00:48:32 2020 +0800) rename
    commit fa011aa (ZheyuYe, Tue Jun 30 23:40:28 2020 +0800) update
    commit 402d625 (ZheyuYe, Tue Jun 30 21:40:30 2020 +0800) multiprocessing for wiki
    commit ddbde75 (ZheyuYe, Tue Jun 30 20:41:35 2020 +0800) fix bookcorpus
    commit 6cc5ccd (ZheyuYe, Tue Jun 30 16:39:12 2020 +0800) fix wiki
    commit 9773efd (ZheyuYe, Tue Jun 30 15:52:13 2020 +0800) fix openwebtext
    commit 1fb8eb8 (ZheyuYe, Mon Jun 29 19:51:25 2020 +0800) upload gluon_electra_small_owt
    commit ca83fac (ZheyuYe, Mon Jun 29 18:09:48 2020 +0800) revise train_transformer
    commit 1450f5c (ZheyuYe, Mon Jun 29 18:07:04 2020 +0800) revise
    commit b460bbe (ZheyuYe, Mon Jun 29 17:24:00 2020 +0800) repeat for pretraining
    commit 8ee381b (ZheyuYe, Mon Jun 29 17:06:43 2020 +0800) repeat
    commit aea936f (ZheyuYe, Mon Jun 29 16:39:22 2020 +0800) fix mobilebert
    commit eead164 (ZheyuYe, Sun Jun 28 18:44:28 2020 +0800) fix
    commit 8645115 (ZheyuYe, Sun Jun 28 17:27:43 2020 +0800) update
    commit 2b7f7a3 (ZheyuYe, Sun Jun 28 17:18:00 2020 +0800) fix roberta
    commit 86702fe (ZheyuYe, Sun Jun 28 16:27:43 2020 +0800) use_segmentation
    commit 6d03d7a (ZheyuYe, Sun Jun 28 15:52:40 2020 +0800) fix
    commit 5c0ca43 (ZheyuYe, Sun Jun 28 15:49:48 2020 +0800) fix token_ids
    commit ff7aae8 (ZheyuYe, Sun Jun 28 13:56:07 2020 +0800) fix xlmr
    commit 2070b86 (ZheyuYe, Sun Jun 28 13:54:26 2020 +0800) fix roberta

    commit 70a1887 (Leonard Lausen, Fri Jul 17 00:07:08 2020 +0000) Update for Block API (#1261)
      - Remove params and prefix arguments for MXNet 2 and update parameter sharing implementation
      - Remove Block.name_scope() for MXNet 2
      - Remove self.params.get() and self.params.get_constant()

    commit ea9152b (Xingjian Shi, Thu Jul 16 15:42:04 2020 -0700) Fixes to make the CI more stable (#1265)
      * Some fixes to make the CI more stable
      * add retries
      * Update tokenizers.py

    commit a646c34 (ht, Sun Jul 12 02:49:53 2020 +0800) [FEATURE] update backtranslation and add multinomial sampler (#1259)
      * back translation bash
      * split "lang-pair" para in clean_tok_para_corpus
      * added clean_tok_mono_corpus
      * fix
      * add num_process para
      * fix
      * fix
      * add yml
      * rm yml
      * update cfg name
      * update evaluate
      * added max_update / save_interval_update params
      * fix
      * fix
      * multi gpu inference
      * fix
      * update
      * update multi gpu inference
      * fix
      * fix
      * split evaluate and parallel infer
      * fix
      * test
      * fix
      * update
      * add comments
      * fix
      * remove todo comment
      * revert remove todo comment
      * raw lines remove duplicated '\n'
      * update multinomaial sampler
      * fix
      * fix
      * fix
      * fix
      * sampling
      * update script
      * fix
      * add test_case with k > 1 in topk sampling
      * fix multinomial sampler
      * update docs
      * comments situation eos_id = None
      * fix
      Co-authored-by: Hu <[email protected]>

    commit 83e1f13 (Leonard Lausen, Thu Jul 9 20:57:55 2020 -0700) Use Amazon S3 Transfer Acceleration (#1260)
    commit cd48efd (Leonard Lausen, Tue Jul 7 17:39:42 2020 -0700) Update codecov action to handle different OS and Python versions (#1254) codecov/codecov-action#80 (comment)
    commit 689eba9 (Sheng Zha, Tue Jul 7 09:55:34 2020 -0700) [CI] AWS batch job tool for GluonNLP (Part I) (#1251)
      * AWS batch job tool for GluonNLP
      * limit range
      Co-authored-by: Xingjian Shi <[email protected]>
    commit e06ff01 (Leonard Lausen, Tue Jul 7 08:36:24 2020 -0700) Pin mxnet version range on CI (#1257)

* frozen_params
* remove conversion to a sperate pr
* fix
* fix
* update
* test
* revise
* update performance numbers
* update apply_layerwisw_decay
* use shuffle
* fix mobilebert
* fix vocab_file
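The headline feature, Horovod-based data parallelism for the pretraining and SQuAD fine-tuning scripts, is not itself visible in the file diffs reproduced below, so the following is only a minimal sketch of the usual horovod.mxnet pattern that such scripts build on. It assumes horovod.mxnet and a GPU-enabled MXNet are installed; the tiny Dense network and the Adam settings are placeholders, not the actual model or hyper-parameters used by run_squad.py.

```python
import mxnet as mx
import horovod.mxnet as hvd

hvd.init()                                       # one process per worker
# each worker binds to its own GPU (falls back to CPU for a dry run)
ctx = mx.gpu(hvd.local_rank()) if mx.context.num_gpus() > 0 else mx.cpu()

net = mx.gluon.nn.Dense(2)                       # placeholder for the real model
net.initialize(ctx=ctx)
net(mx.nd.ones((1, 4), ctx=ctx))                 # force parameter creation

# every worker starts from rank 0's parameters
hvd.broadcast_parameters(net.collect_params(), root_rank=0)

# gradients are averaged across workers on each trainer.step()
trainer = hvd.DistributedTrainer(net.collect_params(),
                                 'adam', {'learning_rate': 2e-5})
```

A job of this shape is normally launched with `horovodrun -np <num_workers> python <script>.py ...`; per the commit log above, inference can still be run without Horovod.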
1 parent 2294421 · commit 1f9ad44

File tree

25 files changed (+815, -567 lines)


scripts/datasets/__main__.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@
 from .general_nlp_benchmark import prepare_glue
 from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY
 
-
+# TODO(zheyuye), lazy_import theses data parser functions and data main function
+# and their dependencies by a dictionary mapping the datasets names to the functions.
 def list_all_subcommands():
     out = []
     for key in DATA_PARSER_REGISTRY.list_keys():

scripts/datasets/pretrain_corpus/README.md

Lines changed: 13 additions & 8 deletions
@@ -2,36 +2,41 @@
 
 We provide a series of shared scripts for downloading/preparing the text corpus for pretraining NLP models.
 This helps create a unified text corpus for studying the performance of different pretraining algorithms.
-When releasing the datasets, we follow the [FAIR principle](https://www.go-fair.org/fair-principles/),
-i.e., the dataset needs to be findable, accessible, interoperable, and reusable.
+When releasing the datasets, we follow the [FAIR principle](https://www.go-fair.org/fair-principles/),
+i.e., the dataset needs to be findable, accessible, interoperable, and reusable.
 
 ## BookCorpus
 Unfortunately, we are unable to provide the original [Toronto BookCorpus dataset](https://yknzhu.wixsite.com/mbweb) due to licensing issues.
 
 There are some open source efforts for reproducing the dataset, e.g.,
-using [soskek/bookcorpus](https://github.com/soskek/bookcorpus) or directly downloading the [preprocessed version](https://drive.google.com/file/d/16KCjV9z_FHm8LgZw05RSuk4EsAWPOP_z/view).
-
+using [soskek/bookcorpus](https://github.com/soskek/bookcorpus) or directly downloading the [preprocessed version](https://drive.google.com/file/d/16KCjV9z_FHm8LgZw05RSuk4EsAWPOP_z/view).
+
 Nevertheless, we utilize the [Project Gutenberg](https://www.gutenberg.org/) as an alternative to Toronto BookCorpus.
 
-You can use the following command to download and prepare the Gutenberg dataset.
+You can use the following command to download and prepare the Gutenberg dataset.
 
 ```bash
 python prepare_bookcorpus.py --dataset gutenberg
 ```
 
-Also, you should follow the [license](https://www.gutenberg.org/wiki/Gutenberg:The_Project_Gutenberg_License) for using the data.
+Also, you should follow the [license](https://www.gutenberg.org/wiki/Gutenberg:The_Project_Gutenberg_License) for using the data.
 
 ## Wikipedia
 
 Please install [attardi/wikiextractor](https://github.com/attardi/wikiextractor) for preparing the data.
 
-```
+```bash
 # Download
 python prepare_wikipedia.py --mode download --lang en --date latest -o ./
 
 # Properly format the text files
 python prepare_wikipedia.py --mode format -i [path-to-wiki.xml.bz2] -o ./
 
+```
+The process of downloading and formatting is time consuming, and we offer an alternative solution to download the prepared raw text file from S3 bucket. This raw text file is in English and was dumped at 2020-06-20 being formated by the above very process (` --lang en --date 20200620`).
+
+```bash
+python prepare_wikipedia.py --mode download_prepared -o ./
 ```
 ### References
 - [NVIDIA/DeepLearningExamples](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/LanguageModeling/BERT)
@@ -43,7 +48,7 @@ You can download the OpenWebText from [link](https://skylion007.github.io/OpenWe
 After downloading and extracting the OpenWebText (i.e., `tar xf openwebtext.tar.xz`), you can use the following command to preprocess the dataset.
 
 ```bash
-python prepare_openwebtext.py --input openwebtext/ --output prepared_owt
+python prepare_openwebtext.py --input openwebtext/ --output prepared_owt --shuffle
 ```
 
 In this step, the archived txt are directly read without decompressing.
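The last changed line of the README refers to reading the `.xz` shards directly rather than unpacking them first. A rough stand-alone illustration of that idea, assuming a shard name of the `urlsf_subset*_data.xz` form (the exact name here is made up):

```python
import tarfile

# tarfile auto-detects the xz compression, so the archive is streamed in place
with tarfile.open('openwebtext/urlsf_subset00-1_data.xz') as archive:
    for member in archive.getmembers():
        if not member.isfile():
            continue
        with archive.extractfile(member) as f:
            for line in f:                      # raw bytes, one text line at a time
                text = line.decode().strip()
```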

scripts/datasets/pretrain_corpus/prepare_bookcorpus.py

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ def main(args):
                 filename = os.path.basename(name)
                 f.extract(name, os.path.join(save_dir, filename))
             else:
+                # TODO(zheyuye), format for pretraining
                 raise NotImplementedError
     else:
         raise NotImplementedError

scripts/datasets/pretrain_corpus/prepare_openwebtext.py

Lines changed: 12 additions & 6 deletions
@@ -51,8 +51,9 @@ def extract_files(full_name, output_dir, shuffle=False):
     """
     if not full_name.endswith(".xz"):
         return
-    file_prefix = re.split('\.|/',full_name)[1]
-    with open("{}.txt".format(os.path.join(output_dir, file_prefix)),"w") as fp:
+    file_prefix = re.split(r'\.|/', full_name)[-2]
+    file_prefix = file_prefix.replace('urlsf_subset', 'openwebtext-prepared-')
+    with open("{}.txt".format(os.path.join(output_dir, file_prefix)), "w") as fp:
         with tarfile.open(full_name) as t:
             txt_names = t.getnames()
             if shuffle:
@@ -63,9 +64,9 @@ def extract_files(full_name, output_dir, shuffle=False):
                     # skip empty line
                     line = line.strip()
                     if line:
-                        fp.write(line.decode()+'\n')
+                        fp.write(line.decode() + '\n')
                 # Two extra line break to mark the document separation
-                fp.write('\n\n')
+                fp.write('\n')
 
 
 @DATA_MAIN_REGISTRY.register('prepare_openwebtext')
@@ -76,11 +77,16 @@ def main(args):
     fnames = sorted(os.listdir(args.input))
     fnames = [os.path.join(args.input, fname) for fname in fnames]
     if args.shuffle:
-        fnames = random.shuffle(fnames)
+        random.shuffle(fnames)
     print('Start extracting {} files with {} cores'.format(len(fnames), num_process))
     start_time = time.time()
     with multiprocessing.Pool(num_process) as pool:
-        iter = pool.imap(functools.partial(extract_files, output_dir=args.output, shuffle=args.shuffle), fnames)
+        iter = pool.imap(
+            functools.partial(
+                extract_files,
+                output_dir=args.output,
+                shuffle=args.shuffle),
+            fnames)
         for f_index, _ in enumerate(iter):
            if f_index > 0 and f_index % 250 == 0:
                elapsed = time.time() - start_time
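For concreteness, this is what the revised `file_prefix` logic above produces for a hypothetical shard name; indexing with `[-2]` picks the file stem no matter how many directory components precede it, and the rename makes the output files self-describing:

```python
import re

full_name = 'openwebtext/urlsf_subset20-93_data.xz'       # made-up shard name
file_prefix = re.split(r'\.|/', full_name)[-2]             # 'urlsf_subset20-93_data'
file_prefix = file_prefix.replace('urlsf_subset', 'openwebtext-prepared-')
print(file_prefix)                                          # openwebtext-prepared-20-93_data
```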

scripts/datasets/pretrain_corpus/prepare_wikipedia.py

Lines changed: 112 additions & 44 deletions
@@ -1,10 +1,15 @@
 """Prepare the Wikipedia dataset that contain cleaned articles of all languages."""
 import os
 import sys
-import argparse
 import glob
-from gluonnlp.utils.misc import download
-from gluonnlp.registry import DATA_PARSER_REGISTRY, DATA_MAIN_REGISTRY
+import math
+import time
+import tarfile
+import argparse
+import multiprocessing
+
+from gluonnlp.registry import DATA_MAIN_REGISTRY, DATA_PARSER_REGISTRY
+from gluonnlp.utils.misc import download, load_checksum_stats
 
 _CITATION = """\
 @ONLINE {wikidump,
@@ -47,6 +52,13 @@
 _BASE_URL_TMPL\
     = "https://dumps.wikimedia.org/{lang}wiki/{date}/{lang}wiki-{date}-pages-articles.xml.bz2"
 _CURR_DIR = os.path.realpath(os.path.dirname(os.path.realpath(__file__)))
+_URL_FILE_STATS_PATH = os.path.join(_CURR_DIR, '..', 'url_checksums', 'wikipedia.txt')
+_URL_FILE_STATS = load_checksum_stats(_URL_FILE_STATS_PATH)
+
+_URLS = {
+    'wikipedia-en-20200620':
+        'https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikipedia-en-20200620.tar.gz',
+}
 
 
 def get_url(lang, date):
@@ -55,64 +67,72 @@ def get_url(lang, date):
 
 def try_import_wikiextractor():
     try:
+        sys.path.append(_CURR_DIR)
         import WikiExtractor
     except ImportError:
         try:
             download(
-                'https://raw.githubusercontent.com/attardi/wikiextractor/'
-                '16186e290d9eb0eb3a3784c6c0635a9ed7e855c3/WikiExtractor.py',
+                'https://raw.githubusercontent.com/attardi/wikiextractor/master/WikiExtractor.py',
                 path=os.path.join(_CURR_DIR, 'WikiExtractor.py'),
                 sha1_hash='3c4896a837b75c476d23c037e8d6c7fdfd9a29eb')
+            sys.path.append(_CURR_DIR)
             import WikiExtractor
-        except:
+        except BaseException:
            raise ImportError('Cannot import WikiExtractor! You can download the "WikiExtractor.py"'
                              ' in https://github.com/attardi/wikiextractor to {}'
                              .format(_CURR_DIR))
     return WikiExtractor
 
 
-class WikicorpusTextFormatting:
-    def __init__(self, wiki_path, output_filename, recursive=False):
-        self.wiki_path = wiki_path
-        self.recursive = recursive
-        self.output_filename = output_filename
-
-    # This puts one article per line
-    def merge(self):
-        with open(self.output_filename, mode='w', newline='\n') as ofile:
-            for dirname in glob.glob(os.path.join(self.wiki_path, '*'), recursive=False):
-                for filename in glob.glob(os.path.join(dirname, 'wiki_*'), recursive=self.recursive):
-                    print(filename)
-                    article_lines = []
-                    article_open = False
-
-                    with open(filename, mode='r', newline='\n') as file:
-                        for line in file:
-                            if '<doc id=' in line:
-                                article_open = True
-                            elif '</doc>' in line:
-                                article_open = False
-                                for oline in article_lines[1:]:
-                                    if oline != '\n':
-                                        ofile.write(oline.rstrip() + " ")
-                                ofile.write("\n\n")
-                                article_lines = []
-                            else:
-                                if article_open:
-                                    article_lines.append(line)
+def get_formatting_list(wiki_path, recursive=False):
+    """
+    get formatting list of file names from extracted content
+    """
+    filenames = []
+    for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False):
+        for filename in glob.glob(os.path.join(dirname, 'wiki_*'), recursive=recursive):
+            filenames.append(filename)
+    return filenames
+
+
+def merge(x):
+    """
+    Puts one article per line
+    """
+    file_list, output_filename = x
+    article_lines = []
+    article_open = False
+
+    with open(output_filename, mode='w', newline='\n') as ofile:
+        for filename in file_list:
+            with open(filename, mode='r', newline='\n') as file:
+                for line in file:
+                    if '<doc id=' in line:
+                        article_open = True
+                    elif '</doc>' in line:
+                        article_open = False
+                        for oline in article_lines[1:]:
+                            if oline != '\n':
+                                ofile.write(oline.rstrip() + " ")
+                        ofile.write("\n\n")
+                        article_lines = []
+                    else:
+                        if article_open:
+                            article_lines.append(line)
 
 
 @DATA_PARSER_REGISTRY.register('prepare_wikipedia')
 def get_parser():
     parser = argparse.ArgumentParser(description='Download and Prepare the Wikipedia')
     parser.add_argument('--mode', type=str,
                         default='download+format',
-                        choices=['download', 'format', 'download+format'],
+                        choices=['download', 'format', 'download+format', 'download_prepared'],
                         help='Specify the action you want the app to take. '
                              '"download" means to download the Wikipedia dump. '
                              '"format" means to extract the content and '
                              'format it for pretraining. "download+format" means to combine '
-                             'these two options')
+                             'these two options'
+                             '"download_prepared" downloads the prepared txt from S3 directly')
     parser.add_argument('--lang', type=str, default='en',
                         help='Language of the wikipedia dump file.'
                              'We only support English and Chinese for current version')
@@ -124,8 +144,13 @@ def get_parser():
     parser.add_argument("-o", "--output", default="wikicorpus",
                         help="directory for downloaded or formatted files")
     parser.add_argument("-b", "--bytes", default="100M",
-                        help="maximum bytes per output file (default %(default)s)",
+                        help="maximum bytes per extracted file (default %(default)s)",
                         metavar="n[KMG]")
+    parser.add_argument("--num_process", type=int, default=8,
+                        help="number of processes for multiprocessing")
+    parser.add_argument("--num_out_files", type=int, default=1000,
+                        help="Number of desired output files, where each is processed"
+                             " independently by a worker.")
     return parser
 
 
@@ -145,32 +170,75 @@ def download_wikicorpus(lang, date, output):
     return output_file
 
 
-def format_wikicorpus(input, output, bytes):
+def format_wikicorpus(input, output, bytes, num_process, num_out_files):
     if input is None:
         raise ValueError('input file is empty.')
     if not input.endswith('xml.bz2'):
         raise ValueError('input file not *.xml.bz2.')
     if not os.path.exists(output):
         os.makedirs(output)
+
     # Use WikiExtractor to extract the content
     WikiExtractor = try_import_wikiextractor()
     wiki_path = os.path.join(output, 'extracted')
     sys.argv = ['prog', '-b', bytes, '-o', wiki_path, input]
     WikiExtractor.main()
-    output_filename = os.path.join(output, 'wikicorpus_one_article_per_line.txt')
-    wiki_formatter = WikicorpusTextFormatting(wiki_path, output_filename, recursive=True)
-    wiki_formatter.merge()
+
+    # Merge extracted content into txt files
+    prepared_path = os.path.join(output, 'prepared_wikipedia')
+    if not os.path.exists(prepared_path):
+        os.makedirs(prepared_path)
+    filenames = get_formatting_list(wiki_path, recursive=True)
+    num_files = len(filenames)
+    num_out_files = min(num_out_files, num_files)
+    file_volume = math.ceil(num_files / num_out_files)
+    splited_files = [filenames[i: i + file_volume] for i in range(0, num_files, file_volume)]
+    num_out_files = len(splited_files)
+    output_files = [
+        os.path.join(
+            prepared_path,
+            "wikipedia-prepared-{}.txt".format(
+                str(i).zfill(4))) for i in range(num_out_files)]
+    print("All prepared raw text will be saved in {} txt files".format(num_out_files))
+    num_process = min(num_process, num_out_files)
+    print('Start preprocessing {} text files with {} cores'.format(num_files, num_process))
+    process_args = [(splited_files[i], output_files[i]) for i in range(num_out_files)]
+
+    start_time = time.time()
+    with multiprocessing.Pool(num_process) as pool:
+        f_read = 0
+        for i, _ in enumerate(pool.imap(merge, process_args)):
+            elapsed = time.time() - start_time
+            f_read += len(splited_files[i])
+            print("prepared {:} files, Elapsed: {:.2f}s, ETA: {:.2f}s, ".format(
+                f_read, elapsed, (num_files - f_read) / (num_files / elapsed)))
+    print("Done preparation within {:.2f} seconds".format(elapsed))
 
 
 @DATA_MAIN_REGISTRY.register('prepare_wikipedia')
 def main(args):
+    num_process = min(multiprocessing.cpu_count(), args.num_process)
     if args.mode == 'download':
         download_wikicorpus(args.lang, args.date, args.output)
     elif args.mode == 'format':
-        format_wikicorpus(args.input, args.output, args.bytes)
+        format_wikicorpus(args.input, args.output, args.bytes, num_process, args.num_out_files)
     elif args.mode == 'download+format':
         downloaded_file = download_wikicorpus(args.lang, args.date, args.output)
-        format_wikicorpus(downloaded_file, args.output, args.bytes)
+        format_wikicorpus(downloaded_file, args.output, args.bytes, num_process, args.num_out_files)
+    elif args.mode == 'download_prepared':
+        url = _URLS['wikipedia-en-20200620']
+        file_hash = _URL_FILE_STATS[url]
+        target_download_location = os.path.join(args.output,
+                                                os.path.basename(url))
+        download(url, target_download_location, sha1_hash=file_hash)
+        tar = tarfile.open(target_download_location)
+        names = tar.getnames()
+        print('Start unarchiving raw text files')
+        start_time = time.time()
+        for name in names:
+            tar.extract(name, path=args.output)
+        tar.close()
+        print("Done unarchiving within {:.2f} seconds".format(time.time() - start_time))
     else:
         raise NotImplementedError
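The merging stage above distributes the extracted `wiki_*` shards evenly over `--num_out_files` output files, each handled by one worker. A small illustration of that arithmetic with made-up counts:

```python
import math

filenames = ['wiki_{:05d}'.format(i) for i in range(10000)]    # pretend extracted shards
num_out_files = 1000                                           # --num_out_files
file_volume = math.ceil(len(filenames) / num_out_files)        # 10 shards per output file
splited_files = [filenames[i: i + file_volume]
                 for i in range(0, len(filenames), file_volume)]
print(len(splited_files), len(splited_files[0]))               # -> 1000 10
```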

scripts/datasets/url_checksums/wikipedia.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/pretrain_corpus/wikipedia-en-20200620.tar.gz 1e1d77c31622744aaa45ff5bfbfca397154d9186 5068070627
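The single added line records a URL, its SHA1 hash, and its size in bytes. prepare_wikipedia.py turns this file into a `{url: sha1}` mapping via `load_checksum_stats` from `gluonnlp.utils.misc`; the stand-alone parser below only illustrates the file format and is not the library's implementation:

```python
def read_url_checksums(path):
    """Parse lines of the form '<url> <sha1> <size-in-bytes>' into {url: sha1}."""
    stats = {}
    with open(path) as f:
        for line in f:
            url, sha1, _size = line.split()
            stats[url] = sha1
    return stats

# hypothetical usage, mirroring what prepare_wikipedia.py does:
# stats = read_url_checksums('scripts/datasets/url_checksums/wikipedia.txt')
# sha1 = stats['https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/'
#              'pretrain_corpus/wikipedia-en-20200620.tar.gz']
```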

scripts/machine_translation/train_transformer.py

Lines changed: 1 addition & 4 deletions
@@ -367,7 +367,7 @@ def train(args):
                                                  seed=args.seed)
     else:
         raise NotImplementedError
-
+
     batchify_fn = bf.Tuple(bf.Pad(), bf.Pad(), bf.Stack(), bf.Stack(), bf.Stack())
     train_data_loader = gluon.data.DataLoader(data_train,
                                               batch_sampler=train_batch_sampler,
@@ -387,7 +387,6 @@
     log_start_time = time.time()
     num_params, num_fixed_params = None, None
     # TODO(sxjscience) Add a log metric class
-
     accum_count = 0
     loss_denom = 0
     n_train_iters = 0
@@ -471,12 +470,10 @@
                                            deduplicate=True)
                 if args.max_update > 0 and n_train_iters >= args.max_update:
                     break
-
         if args.epochs > 0:
             model.save_parameters(os.path.join(args.save_dir,
                                                'epoch{:d}.params'.format(epoch_id)),
                                   deduplicate=True)
-
         avg_valid_loss = validation(model, val_data_loader, ctx_l)
         logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
                      .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))
