Skip to content

Commit 1faea3c

Browse files
move cosine loss compute to function, fix some args
1 parent e99daaf commit 1faea3c

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

onmt/modules/copy_generator.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -186,22 +186,23 @@ def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
186186
self.tgt_vocab = tgt_vocab
187187
self.normalize_by_length = normalize_by_length
188188

189-
def _make_shard_state(self, batch, output, range_, attns):
189+
def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns):
190190
"""See base class for args description."""
191191
if getattr(batch, "alignment", None) is None:
192192
raise AssertionError("using -copy_attn you need to pass in "
193193
"-dynamic_dict during preprocess stage.")
194194

195195
shard_state = super(CopyGeneratorLossCompute, self)._make_shard_state(
196-
batch, output, range_, attns)
196+
batch, output, enc_src, enc_tgt, range_, attns)
197197

198198
shard_state.update({
199199
"copy_attn": attns.get("copy"),
200200
"align": batch.alignment[range_[0] + 1: range_[1]]
201201
})
202202
return shard_state
203203

204-
def _compute_loss(self, batch, output, target, copy_attn, align,
204+
def _compute_loss(self, batch, normalization, output, target,
205+
copy_attn, align, enc_src=None, enc_tgt=None,
205206
std_attn=None, coverage_attn=None):
206207
"""Compute the loss.
207208
@@ -244,8 +245,18 @@ def _compute_loss(self, batch, output, target, copy_attn, align,
244245
offset_align = align[correct_mask] + len(self.tgt_vocab)
245246
target_data[correct_mask] += offset_align
246247

248+
if self.lambda_cosine != 0.0:
249+
cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt)
250+
loss += self.lambda_cosine * (cosine_loss / num_ex)
251+
else:
252+
cosine_loss = None
253+
num_ex = 0
254+
247255
# Compute sum of perplexities for stats
248-
stats = self._stats(loss.sum().clone(), scores_data, target_data)
256+
stats = self._stats(loss.sum().clone(),
257+
cosine_loss.clone() if cosine_loss is not None
258+
else cosine_loss,
259+
scores_data, target_data, num_ex)
249260

250261
# this part looks like it belongs in CopyGeneratorLoss
251262
if self.normalize_by_length:

onmt/utils/loss.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,8 @@ def __init__(self, criterion, generator):
9292
def padding_idx(self):
9393
return self.criterion.ignore_index
9494

95-
def _make_shard_state(self, batch, output, range_, attns=None):
95+
def _make_shard_state(self, batch, enc_src, enc_tgt,
96+
output, range_, attns=None):
9697
"""
9798
Make shard state dictionary for shards() to return iterable
9899
shards for efficient loss computation. Subclass must define
@@ -315,14 +316,7 @@ def _compute_loss(self, batch, normalization, output, target,
315316
loss = loss/float(normalization)
316317

317318
if self.lambda_cosine != 0.0:
318-
max_src = enc_src.max(axis=0)[0]
319-
max_tgt = enc_tgt.max(axis=0)[0]
320-
cosine_loss = torch.nn.functional.cosine_similarity(
321-
max_src.float(), max_tgt.float(), dim=1)
322-
ones = torch.ones(cosine_loss.size()).to(cosine_loss.device)
323-
cosine_loss = ones - cosine_loss
324-
num_ex = cosine_loss.size(0)
325-
cosine_loss = cosine_loss.sum()
319+
cosine_loss, num_ex = self._compute_cosine_loss(enc_src, enc_tgt)
326320
loss += self.lambda_cosine * (cosine_loss / num_ex)
327321
else:
328322
cosine_loss = None
@@ -340,6 +334,16 @@ def _compute_coverage_loss(self, std_attn, coverage_attn):
340334
covloss *= self.lambda_coverage
341335
return covloss
342336

337+
def _compute_cosine_loss(self, enc_src, enc_tgt):
338+
max_src = enc_src.max(axis=0)[0]
339+
max_tgt = enc_tgt.max(axis=0)[0]
340+
cosine_loss = torch.nn.functional.cosine_similarity(
341+
max_src.float(), max_tgt.float(), dim=1)
342+
ones = torch.ones(cosine_loss.size()).to(cosine_loss.device)
343+
cosine_loss = ones - cosine_loss
344+
num_ex = cosine_loss.size(0)
345+
return cosine_loss.sum(), num_ex
346+
343347
def _compute_alignement_loss(self, align_head, ref_align):
344348
"""Compute loss between 2 partial alignment matrix."""
345349
# align_head contains value in [0, 1) presenting attn prob,

0 commit comments

Comments
 (0)