This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 95d0214

martinpopel authored and copybara-github committed

Merge of PR #1834

PiperOrigin-RevId: 321903600

1 parent 2ea8ec1

File tree

3 files changed: +112 −0 lines changed


tensor2tensor/data_generators/all_problems.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@
     "tensor2tensor.data_generators.subject_verb_agreement",
     "tensor2tensor.data_generators.timeseries",
     "tensor2tensor.data_generators.transduction_problems",
+    "tensor2tensor.data_generators.translate_encs_cubbitt",
     "tensor2tensor.data_generators.translate_encs",
     "tensor2tensor.data_generators.translate_ende",
     "tensor2tensor.data_generators.translate_enes",
tensor2tensor/data_generators/translate_encs_cubbitt.py

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+# coding=utf-8
+# Copyright 2020 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data generators for English-Czech backtranslation NMT data-sets.
+
+To use this problem, you need to provide backtranslated (synthetic) data in
+tmp_dir (cs_mono_{en,cs}.txt{0,1,2} - each file of a similar size to the
+authentic training data).
+You can either translate the monolingual data yourself, or you can download
+"csmono" data from CzEng2.0 (http://ufal.mff.cuni.cz/czeng, registration
+needed), which comes with synthetic translations into English produced by a
+backtranslation-trained model, so the final model will use "iterated"
+backtranslation.
+
+To get the best results out of Block-Backtranslation (where blocks of
+synthetic and authentic training data are concatenated without shuffling),
+you should use checkpoint averaging (see t2t-avg-all).
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensor2tensor.data_generators import problem
+from tensor2tensor.data_generators import text_problems
+from tensor2tensor.data_generators import translate
+from tensor2tensor.data_generators import translate_encs
+from tensor2tensor.utils import registry
+
+
+@registry.register_problem
+class TranslateEncsCubbitt(translate_encs.TranslateEncsWmt32k):
+  """Problem spec for English-Czech CUBBITT (CUni Block-Backtranslation-Improved Transformer Translation)."""
+
+  @property
+  def use_vocab_from_other_problem(self):
+    return translate_encs.TranslateEncsWmt32k()
+
+  @property
+  def already_shuffled(self):
+    return True
+
+  @property
+  def skip_random_fraction_when_training(self):
+    return False
+
+  @property
+  def backtranslate_data_filenames(self):
+    """List of pairs of files with matched back-translated data."""
+    # Files must be placed in tmp_dir, each of a similar size to the
+    # authentic data.
+    return [("cs_mono_en.txt%d" % i, "cs_mono_cs.txt%d" % i) for i in [0, 1, 2]]
+
+  @property
+  def dataset_splits(self):
+    """Splits of data to produce and number of output shards for each."""
+    return [{
+        "split": problem.DatasetSplit.TRAIN,
+        "shards": 1,  # Use just 1 shard so as to not mix the data.
+    }, {
+        "split": problem.DatasetSplit.EVAL,
+        "shards": 1,
+    }]
+
+  def generate_samples(self, data_dir, tmp_dir, dataset_split):
+    datasets = self.source_data_files(dataset_split)
+    tag = "train" if dataset_split == problem.DatasetSplit.TRAIN else "dev"
+    data_path = translate.compile_data(
+        tmp_dir, datasets, "%s-compiled-%s" % (self.name, tag))
+    # For eval, use authentic data only.
+    if dataset_split != problem.DatasetSplit.TRAIN:
+      for example in text_problems.text2text_txt_iterator(
+          data_path + ".lang1", data_path + ".lang2"):
+        yield example
+    else:  # For training, mix synthetic and authentic data as follows.
+      for (file1, file2) in self.backtranslate_data_filenames:
+        path1 = os.path.join(tmp_dir, file1)
+        path2 = os.path.join(tmp_dir, file2)
+        # Synthetic data first.
+        for example in text_problems.text2text_txt_iterator(path1, path2):
+          yield example
+        # Now authentic data.
+        for example in text_problems.text2text_txt_iterator(
+            data_path + ".lang1", data_path + ".lang2"):
+          yield example
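
For training, the generator above emits one block of synthetic pairs followed by the full authentic corpus, once per synthetic file pair; with already_shuffled = True and skip_random_fraction_when_training = False, that block order is preserved through datagen and training. A toy sketch of the resulting order, using hypothetical placeholder strings rather than real sentence pairs:

# Hypothetical stand-ins for the three synthetic file pairs and the
# compiled authentic corpus; the values are placeholders, not real data.
synthetic_blocks = [["syn0_a", "syn0_b"], ["syn1_a"], ["syn2_a"]]
authentic = ["auth_a", "auth_b"]

stream = []
for block in synthetic_blocks:
    stream += block      # synthetic block first
    stream += authentic  # then the full authentic data, repeated per block

print(stream)
# ['syn0_a', 'syn0_b', 'auth_a', 'auth_b',
#  'syn1_a', 'auth_a', 'auth_b',
#  'syn2_a', 'auth_a', 'auth_b']

Note that the authentic corpus is repeated after every synthetic block, which keeps the synthetic-to-authentic ratio roughly 1:1 when each synthetic file is about the size of the authentic data, as the docstring requires.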

tensor2tensor/models/transformer.py

Lines changed: 13 additions & 0 deletions
@@ -2186,6 +2186,19 @@ def transformer_base_multistep8():
   return hparams


+@registry.register_hparams
+def transformer_cubbitt():
+  """Transformer hyperparameters used in CUBBITT experiments."""
+  hparams = transformer_big_single_gpu()
+  hparams.learning_rate_schedule = "rsqrt_decay"
+  hparams.batch_size = 2900
+  hparams.learning_rate_warmup_steps = 8000
+  hparams.max_length = 150
+  hparams.layer_prepostprocess_dropout = 0
+  hparams.optimizer = "Adafactor"
+  return hparams
+
+
 @registry.register_hparams
 def transformer_parsing_base():
   """HParams for parsing on WSJ only."""
