Commit 176bafa

Author: Raghavendra Sugeeth P S (committed)
Merge branch 'OpenNMT-master'
Parents: 43c3869 + d9faa98

6 files changed: +50 -30 lines


onmt/inputters/image_dataset.py

Lines changed: 2 additions & 1 deletion
@@ -75,7 +75,8 @@ def read(self, images, side, img_dir=None):
                 img = transforms.ToTensor()(
                     Image.fromarray(cv2.imread(img_path, 0)))
             else:
-                img = transforms.ToTensor()(Image.open(img_path))
+                img = Image.open(img_path).convert('RGB')
+                img = transforms.ToTensor()(img)
             if self.truncate and self.truncate != (0, 0):
                 if not (img.size(1) <= self.truncate[0]
                         and img.size(2) <= self.truncate[1]):
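
The explicit convert('RGB') matters because PIL keeps each file's native mode: a grayscale ('L') or palette ('P') image run straight through ToTensor() yields a 1-channel (or, for 'RGBA', 4-channel) tensor, breaking models that expect 3xHxW input. A minimal sketch (not part of the commit) of the difference:

# Minimal sketch: ToTensor() preserves the PIL mode's channel count,
# so convert('RGB') normalizes every input to a 3-channel tensor.
from PIL import Image
from torchvision import transforms

gray = Image.new('L', (32, 32))            # grayscale image, mode 'L'
print(transforms.ToTensor()(gray).shape)   # torch.Size([1, 32, 32])

rgb = gray.convert('RGB')                  # channel replicated to R, G, B
print(transforms.ToTensor()(rgb).shape)    # torch.Size([3, 32, 32])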

onmt/tests/test_beam_search.py

Lines changed: 7 additions & 9 deletions
@@ -61,21 +61,18 @@ def test_advance_with_all_repeats_gets_blocked(self):
                 # (but it's still the best score, thus we have
                 # [BLOCKED_SCORE, -inf, -inf, -inf, -inf]
                 expected_scores = torch.tensor(
-                    [0] + [-float('inf')] * (beam_sz - 1))\
-                    .repeat(batch_sz, 1)
-                expected_scores[:, 0] = self.BLOCKED_SCORE
+                    [self.BLOCKED_SCORE] + [-float('inf')] * (beam_sz - 1)
+                ).repeat(batch_sz, 1)
                 self.assertTrue(beam.topk_log_probs.equal(expected_scores))
             else:
                 # repetitions keeps maximizing score
                 # index 0 has been blocked, so repeating=>+0.0 score
                 # other indexes are -inf so repeating=>BLOCKED_SCORE
                 # which is higher
                 expected_scores = torch.tensor(
-                    [0] + [-float('inf')] * (beam_sz - 1))\
-                    .repeat(batch_sz, 1)
-                expected_scores[:, :] = self.BLOCKED_SCORE
-                expected_scores = torch.tensor(
-                    self.BLOCKED_SCORE).repeat(batch_sz, beam_sz)
+                    [self.BLOCKED_SCORE] + [-float('inf')] * (beam_sz - 1)
+                ).repeat(batch_sz, 1)
+                self.assertTrue(beam.topk_log_probs.equal(expected_scores))

     def test_advance_with_some_repeats_gets_blocked(self):
         # beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)

@@ -137,7 +134,8 @@ def test_advance_with_some_repeats_gets_blocked(self):

                 expected = torch.full([batch_sz, beam_sz], float("-inf"))
                 expected[:, 0] = no_repeat_score
-                expected[:, 1:] = self.BLOCKED_SCORE
+                expected[:, 1:3] = self.BLOCKED_SCORE
+                expected[:, 3:] = float("-inf")
                 self.assertTrue(
                     beam.topk_log_probs.equal(expected))
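
The rewritten assertions build the expected score matrix in one expression instead of constructing it twice and overwriting slices. A small sketch of the shapes involved (the BLOCKED_SCORE value below is an illustrative stand-in; the test defines its own constant):

# Sketch of the expected-score construction: one row per batch example,
# beam 0 pinned to the blocked score, the remaining beams at -inf.
import torch

BLOCKED_SCORE = -10e20        # stand-in for self.BLOCKED_SCORE
batch_sz, beam_sz = 2, 5

expected_scores = torch.tensor(
    [BLOCKED_SCORE] + [-float('inf')] * (beam_sz - 1)
).repeat(batch_sz, 1)

print(expected_scores.shape)  # torch.Size([2, 5])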

onmt/trainer.py

Lines changed: 16 additions & 9 deletions
@@ -204,7 +204,7 @@ def _update_average(self, step):
             self.moving_average = copy_params
         else:
             average_decay = max(self.average_decay,
-                                1 - (step + 1)/(step + 10))
+                                1 - (step + 1) / (step + 10))
             for (i, avg), cpt in zip(enumerate(self.moving_average),
                                      self.model.parameters()):
                 self.moving_average[i] = \

@@ -306,10 +306,9 @@ def train(self,
                 break

             if (self.model_saver is not None
-                and (save_checkpoint_steps != 0
-                     and step % save_checkpoint_steps == 0)):
-                self.model_saver.save(step, is_best,
-                                      moving_average=self.moving_average)
+                    and (save_checkpoint_steps != 0
+                         and step % save_checkpoint_steps == 0)):
+                self.model_saver.save(step, moving_average=self.moving_average)

             if train_steps > 0 and step >= train_steps:
                 break

@@ -344,7 +343,7 @@ def validate(self, valid_iter, moving_average=None):

             for batch in valid_iter:
                 src, src_lengths = batch.src if isinstance(batch.src, tuple) \
-                                   else (batch.src, None)
+                    else (batch.src, None)
                 tgt = batch.tgt

                 with torch.cuda.amp.autocast(enabled=self.optim.amp):

@@ -390,7 +389,7 @@ def _gradient_accumulation(self, true_batches, normalization, total_stats,
             tgt_outer = batch.tgt

             bptt = False
-            for j in range(0, target_size-1, trunc_size):
+            for j in range(0, target_size - 1, trunc_size):
                 # 1. Create truncated target.
                 tgt = tgt_outer[j: j + trunc_size]

@@ -488,7 +487,12 @@ def _maybe_report_training(self, step, num_steps, learning_rate,
         """
         if self.report_manager is not None:
             return self.report_manager.report_training(
-                step, num_steps, learning_rate, report_stats,
+                step,
+                num_steps,
+                learning_rate,
+                None if self.earlystopper is None
+                else self.earlystopper.current_tolerance,
+                report_stats,
                 multigpu=self.n_gpu > 1)

     def _report_step(self, learning_rate, step, train_stats=None,

@@ -499,7 +503,10 @@ def _report_step(self, learning_rate, step, train_stats=None,
         """
         if self.report_manager is not None:
             return self.report_manager.report_step(
-                learning_rate, step, train_stats=train_stats,
+                learning_rate,
+                None if self.earlystopper is None
+                else self.earlystopper.current_tolerance,
+                step, train_stats=train_stats,
                 valid_stats=valid_stats)

     def maybe_noise_source(self, batch):
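
Beyond spacing cleanups, two changes are substantive: model_saver.save drops the is_best argument, and both report paths now pass the early stopper's current_tolerance (or None when early stopping is disabled). For the decay expression touched in the first hunk, a quick arithmetic sketch of how the scheduled term behaves:

# The scheduled term 1 - (step + 1) / (step + 10) starts at 0.9 and
# decays toward 0 as training progresses, so max(average_decay, ...)
# keeps the configured average_decay as a lower bound.
for step in (0, 10, 100, 1000):
    print(step, round(1 - (step + 1) / (step + 10), 4))
# prints: 0 0.9 / 10 0.45 / 100 0.0818 / 1000 0.0089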

onmt/translate/decode_strategy.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,5 @@
 import torch
+from copy import deepcopy


 class DecodeStrategy(object):

@@ -184,7 +185,7 @@ def maybe_update_forbidden_tokens(self):
                 # Reordering forbidden_tokens following beam selection
                 # We rebuild a dict to ensure we get the value and not the pointer
                 forbidden_tokens.append(
-                    dict(self.forbidden_tokens[path_idx]))
+                    deepcopy(self.forbidden_tokens[path_idx]))

                 # Grabing the newly selected tokens and associated ngram
                 current_ngram = tuple(seq[-n:].tolist())
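
Swapping dict() for deepcopy matters because dict() is a shallow copy: the outer mapping is new, but the values are shared, so mutating one beam path's entry would corrupt another's. A self-contained sketch (the tuple-to-set shape below mirrors how forbidden n-grams are typically stored, but is an assumption here):

# dict() copies only the outer mapping; the inner sets remain shared.
from copy import deepcopy

original = {('the', 'cat'): {7, 9}}   # assumed shape: ngram prefix -> token ids
shallow = dict(original)
shallow[('the', 'cat')].add(42)
print(42 in original[('the', 'cat')])  # True -- mutation leaked through

original = {('the', 'cat'): {7, 9}}
deep = deepcopy(original)
deep[('the', 'cat')].add(42)
print(42 in original[('the', 'cat')])  # False -- deepcopy isolated the sets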

onmt/utils/report_manager.py

Lines changed: 20 additions & 9 deletions
@@ -49,7 +49,7 @@ def start(self):
     def log(self, *args, **kwargs):
         logger.info(*args, **kwargs)

-    def report_training(self, step, num_steps, learning_rate,
+    def report_training(self, step, num_steps, learning_rate, patience,
                         report_stats, multigpu=False):
         """
         This is the user-defined batch-level traing progress

@@ -72,7 +72,7 @@ def report_training(self, step, num_steps, learning_rate,
             report_stats = \
                 onmt.utils.Statistics.all_gather_stats(report_stats)
             self._report_training(
-                step, num_steps, learning_rate, report_stats)
+                step, num_steps, learning_rate, patience, report_stats)
             return onmt.utils.Statistics()
         else:
             return report_stats

@@ -81,17 +81,22 @@ def _report_training(self, *args, **kwargs):
         """ To be overridden """
         raise NotImplementedError()

-    def report_step(self, lr, step, train_stats=None, valid_stats=None):
+    def report_step(self, lr, patience, step, train_stats=None,
+                    valid_stats=None):
         """
         Report stats of a step

         Args:
+            lr(float): current learning rate
+            patience(int): current patience
+            step(int): current step
             train_stats(Statistics): training stats
             valid_stats(Statistics): validation stats
-            lr(float): current learning rate
         """
         self._report_step(
-            lr, step, train_stats=train_stats, valid_stats=valid_stats)
+            lr, patience, step,
+            train_stats=train_stats,
+            valid_stats=valid_stats)

     def _report_step(self, *args, **kwargs):
         raise NotImplementedError()

@@ -111,12 +116,13 @@ def __init__(self, report_every, start_time=-1., tensorboard_writer=None):
         super(ReportMgr, self).__init__(report_every, start_time)
         self.tensorboard_writer = tensorboard_writer

-    def maybe_log_tensorboard(self, stats, prefix, learning_rate, step):
+    def maybe_log_tensorboard(self, stats, prefix, learning_rate,
+                              patience, step):
         if self.tensorboard_writer is not None:
             stats.log_tensorboard(
-                prefix, self.tensorboard_writer, learning_rate, step)
+                prefix, self.tensorboard_writer, learning_rate, patience, step)

-    def _report_training(self, step, num_steps, learning_rate,
+    def _report_training(self, step, num_steps, learning_rate, patience,
                          report_stats):
         """
         See base class method `ReportMgrBase.report_training`.

@@ -127,12 +133,15 @@ def _report_training(self, step, num_steps, learning_rate,
             self.maybe_log_tensorboard(report_stats,
                                        "progress",
                                        learning_rate,
+                                       patience,
                                        step)
             report_stats = onmt.utils.Statistics()

         return report_stats

-    def _report_step(self, lr, step, train_stats=None, valid_stats=None):
+    def _report_step(self, lr, patience, step,
+                     train_stats=None,
+                     valid_stats=None):
         """
         See base class method `ReportMgrBase.report_step`.
         """

@@ -143,6 +152,7 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None):
             self.maybe_log_tensorboard(train_stats,
                                        "train",
                                        lr,
+                                       patience,
                                        step)

         if valid_stats is not None:

@@ -152,4 +162,5 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None):
             self.maybe_log_tensorboard(valid_stats,
                                        "valid",
                                        lr,
+                                       patience,
                                        step)
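
Every signature here gains patience positionally right after the learning-rate argument, so the call sites in trainer.py above must change in lockstep. A self-contained sketch of the new calling convention (the class name is a stand-in, not OpenNMT code):

# Stand-in demo of the new positional order: (lr, patience, step).
class DemoReportMgr:
    def report_step(self, lr, patience, step,
                    train_stats=None, valid_stats=None):
        print("step=%d lr=%g patience=%s" % (step, lr, patience))

mgr = DemoReportMgr()
mgr.report_step(0.001, None, 500)   # no early stopper -> patience is None
mgr.report_step(0.001, 3, 1000)     # early stopper with tolerance 3 left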

onmt/utils/statistics.py

Lines changed: 3 additions & 1 deletion
@@ -126,11 +126,13 @@ def output(self, step, num_steps, learning_rate, start):
                time.time() - start))
         sys.stdout.flush()

-    def log_tensorboard(self, prefix, writer, learning_rate, step):
+    def log_tensorboard(self, prefix, writer, learning_rate, patience, step):
         """ display statistics to tensorboard """
         t = self.elapsed_time()
         writer.add_scalar(prefix + "/xent", self.xent(), step)
         writer.add_scalar(prefix + "/ppl", self.ppl(), step)
         writer.add_scalar(prefix + "/accuracy", self.accuracy(), step)
         writer.add_scalar(prefix + "/tgtper", self.n_words / t, step)
         writer.add_scalar(prefix + "/lr", learning_rate, step)
+        if patience is not None:
+            writer.add_scalar(prefix + "/patience", patience, step)
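
Since patience is None whenever no early stopper is configured, the guard simply skips the scalar rather than passing None to add_scalar. A minimal usage sketch of the same guard pattern (the log directory name is arbitrary):

# Log the patience scalar only when early stopping supplies a value.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/demo")   # arbitrary log dir for the sketch
for step, patience in [(100, 4), (200, 3), (300, None)]:
    writer.add_scalar("valid/lr", 0.001, step)
    if patience is not None:
        writer.add_scalar("valid/patience", patience, step)
writer.close()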
