@@ -186,22 +186,23 @@ def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
186
186
self .tgt_vocab = tgt_vocab
187
187
self .normalize_by_length = normalize_by_length
188
188
189
def _make_shard_state(self, batch, output, enc_src, enc_tgt, range_, attns):
    """Build the shard state and add the copy-generator entries.

    The positional arguments are forwarded unchanged to the parent
    class's ``_make_shard_state``; see the base class for their
    description. Two extra keys are then added to the returned state:

    * ``"copy_attn"``: ``attns.get("copy")`` (``None`` when absent).
    * ``"align"``: ``batch.alignment`` sliced from ``range_[0] + 1``
      up to ``range_[1]``.

    Raises:
        AssertionError: if ``batch`` has no ``alignment`` attribute or
            the attribute is ``None``.
    """
    has_alignment = getattr(batch, "alignment", None) is not None
    if not has_alignment:
        raise AssertionError("using -copy_attn you need to pass in "
                             "-dynamic_dict during preprocess stage.")

    state = super(CopyGeneratorLossCompute, self)._make_shard_state(
        batch, output, enc_src, enc_tgt, range_, attns)

    # Evaluate both values before touching the state, in the same order
    # as a dict literal would.
    copy_attn = attns.get("copy")
    align = batch.alignment[range_[0] + 1: range_[1]]
    state.update(copy_attn=copy_attn, align=align)
    return state
203
203
204
- def _compute_loss (self , batch , output , target , copy_attn , align ,
204
+ def _compute_loss (self , batch , normalization , output , target ,
205
+ copy_attn , align , enc_src = None , enc_tgt = None ,
205
206
std_attn = None , coverage_attn = None ):
206
207
"""Compute the loss.
207
208
@@ -244,8 +245,18 @@ def _compute_loss(self, batch, output, target, copy_attn, align,
244
245
offset_align = align [correct_mask ] + len (self .tgt_vocab )
245
246
target_data [correct_mask ] += offset_align
246
247
248
+ if self .lambda_cosine != 0.0 :
249
+ cosine_loss , num_ex = self ._compute_cosine_loss (enc_src , enc_tgt )
250
+ loss += self .lambda_cosine * (cosine_loss / num_ex )
251
+ else :
252
+ cosine_loss = None
253
+ num_ex = 0
254
+
247
255
# Compute sum of perplexities for stats
248
- stats = self ._stats (loss .sum ().clone (), scores_data , target_data )
256
+ stats = self ._stats (loss .sum ().clone (),
257
+ cosine_loss .clone () if cosine_loss is not None
258
+ else cosine_loss ,
259
+ scores_data , target_data , num_ex )
249
260
250
261
# this part looks like it belongs in CopyGeneratorLoss
251
262
if self .normalize_by_length :
0 commit comments