Skip to content

Commit 306b2e5

Browse files
broadcast instead of explicitly create ones
1 parent 9d26360 commit 306b2e5

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

onmt/utils/loss.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,7 @@ def _compute_cosine_loss(self, enc_src, enc_tgt):
339339
max_tgt = enc_tgt.max(axis=0)[0]
340340
cosine_loss = torch.nn.functional.cosine_similarity(
341341
max_src.float(), max_tgt.float(), dim=1)
342-
ones = torch.ones(cosine_loss.size()).to(cosine_loss.device)
343-
cosine_loss = ones - cosine_loss
342+
cosine_loss = 1 - cosine_loss
344343
num_ex = cosine_loss.size(0)
345344
return cosine_loss.sum(), num_ex
346345

0 commit comments

Comments
 (0)