From e667b82a7dc31d782e9ab8e28dedc39bd21cebe8 Mon Sep 17 00:00:00 2001
From: Erick Fonseca
Date: Sun, 29 Sep 2019 13:51:34 +0100
Subject: [PATCH 1/2] Fixed ignore_index

---
 entmax/losses.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/entmax/losses.py b/entmax/losses.py
index 0113234..11ec796 100644
--- a/entmax/losses.py
+++ b/entmax/losses.py
@@ -14,17 +14,22 @@ def __init__(self, ignore_index=-100, reduction="elementwise_mean"):
         super(_GenericLoss, self).__init__()
 
     def forward(self, X, target):
+        if self.ignore_index is not None:
+            num_samples = target.size(0)
+            valid_positions = target != self.ignore_index
+            target = target[valid_positions]
+            X = X[valid_positions]
+
         loss = self.loss(X, target)
-        if self.ignore_index >= 0:
-            ignored_positions = target == self.ignore_index
-            size = float((target.size(0) - ignored_positions.sum()).item())
-            loss.masked_fill_(ignored_positions, 0.0)
-        else:
-            size = float(target.size(0))
+
+        if self.reduction == "none" and self.ignore_index is not None:
+            nonzero_loss = loss
+            loss = torch.zeros(num_samples, device=X.device)
+            loss[valid_positions] = nonzero_loss
         if self.reduction == "sum":
             loss = loss.sum()
         elif self.reduction == "elementwise_mean":
-            loss = loss.sum() / size
+            loss = loss.mean()
         return loss
 
 

From 83d7bab1f43b9ff9b36ecbf26c26f25c395f30a2 Mon Sep 17 00:00:00 2001
From: Erick Fonseca
Date: Sun, 29 Sep 2019 13:52:50 +0100
Subject: [PATCH 2/2] Added reduction "mean" to loss functions

---
 entmax/losses.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/entmax/losses.py b/entmax/losses.py
index 11ec796..1d6bfc1 100644
--- a/entmax/losses.py
+++ b/entmax/losses.py
@@ -7,8 +7,10 @@
 
 
 class _GenericLoss(nn.Module):
-    def __init__(self, ignore_index=-100, reduction="elementwise_mean"):
-        assert reduction in ["elementwise_mean", "sum", "none"]
+    def __init__(self, ignore_index=-100, reduction="mean"):
+        assert reduction in ["elementwise_mean", "sum", "none", "mean"]
+        if reduction == "elementwise_mean":
+            reduction = "mean"
         self.reduction = reduction
         self.ignore_index = ignore_index
         super(_GenericLoss, self).__init__()
@@ -28,7 +30,7 @@ def forward(self, X, target):
             loss[valid_positions] = nonzero_loss
         if self.reduction == "sum":
             loss = loss.sum()
-        elif self.reduction == "elementwise_mean":
+        elif self.reduction == "mean":
             loss = loss.mean()
         return loss
 
@@ -257,7 +259,7 @@ def loss(self, X, target):
 
 
 class SparsemaxLoss(_GenericLoss):
-    def __init__(self, k=None, ignore_index=-100, reduction="elementwise_mean"):
+    def __init__(self, k=None, ignore_index=-100, reduction="mean"):
         self.k = k
         super(SparsemaxLoss, self).__init__(ignore_index, reduction)
 
@@ -271,7 +273,7 @@ def __init__(
         alpha=1.5,
         n_iter=50,
         ignore_index=-100,
-        reduction="elementwise_mean",
+        reduction="mean",
     ):
         self.alpha = alpha
         self.n_iter = n_iter
@@ -282,7 +284,7 @@ def loss(self, X, target):
 
 
 class Entmax15Loss(_GenericLoss):
-    def __init__(self, k=100, ignore_index=-100, reduction="elementwise_mean"):
+    def __init__(self, k=100, ignore_index=-100, reduction="mean"):
         self.k = k
         super(Entmax15Loss, self).__init__(ignore_index, reduction)
 
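
Usage sketch of the behaviour these two patches introduce, assuming the patched entmax.losses module is importable; the tensor shapes and values below are illustrative only.

    import torch
    from entmax.losses import SparsemaxLoss

    criterion = SparsemaxLoss(reduction="mean", ignore_index=-100)
    scores = torch.randn(4, 5, requires_grad=True)  # (batch, n_classes)
    target = torch.tensor([1, 3, -100, 0])          # index 2 carries ignore_index
    loss = criterion(scores, target)                # averaged over the 3 valid positions only
    loss.backward()

    # With reduction="none" the patched forward() returns a per-sample vector
    # whose ignored positions are filled with zeros.
    per_sample = SparsemaxLoss(reduction="none")(scores, target)  # shape (4,), zero at index 2

The new "mean" name matches the reduction names used by torch.nn losses since PyTorch 1.0, while "elementwise_mean" is still accepted and silently mapped to "mean" for backward compatibility.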