Skip to content

Add ListMLE Loss #130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import os
import shutil

import pre_commit
import namex

PACKAGE = "keras_rs"
Expand Down
3 changes: 2 additions & 1 deletion keras_rs/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
since your modifications would be overwritten.
"""


from keras_rs import layers as layers
from keras_rs import losses as losses
from keras_rs import metrics as metrics
from keras_rs.src.version import __version__ as __version__
from keras_rs.src.version import version as version
from keras_rs.src.version import __version__ as __version__
41 changes: 11 additions & 30 deletions keras_rs/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,15 @@
since your modifications would be overwritten.
"""

from keras_rs.src.layers.embedding.distributed_embedding import (
DistributedEmbedding as DistributedEmbedding,
)
from keras_rs.src.layers.embedding.distributed_embedding_config import (
FeatureConfig as FeatureConfig,
)
from keras_rs.src.layers.embedding.distributed_embedding_config import (
TableConfig as TableConfig,
)
from keras_rs.src.layers.embedding.embed_reduce import (
EmbedReduce as EmbedReduce,
)
from keras_rs.src.layers.feature_interaction.dot_interaction import (
DotInteraction as DotInteraction,
)
from keras_rs.src.layers.feature_interaction.feature_cross import (
FeatureCross as FeatureCross,
)
from keras_rs.src.layers.retrieval.brute_force_retrieval import (
BruteForceRetrieval as BruteForceRetrieval,
)
from keras_rs.src.layers.retrieval.hard_negative_mining import (
HardNegativeMining as HardNegativeMining,
)
from keras_rs.src.layers.retrieval.remove_accidental_hits import (
RemoveAccidentalHits as RemoveAccidentalHits,
)

from keras_rs.src.layers.embedding.distributed_embedding import DistributedEmbedding as DistributedEmbedding
from keras_rs.src.layers.embedding.distributed_embedding_config import FeatureConfig as FeatureConfig
from keras_rs.src.layers.embedding.distributed_embedding_config import TableConfig as TableConfig
from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce as EmbedReduce
from keras_rs.src.layers.feature_interaction.dot_interaction import DotInteraction as DotInteraction
from keras_rs.src.layers.feature_interaction.feature_cross import FeatureCross as FeatureCross
from keras_rs.src.layers.retrieval.brute_force_retrieval import BruteForceRetrieval as BruteForceRetrieval
from keras_rs.src.layers.retrieval.hard_negative_mining import HardNegativeMining as HardNegativeMining
from keras_rs.src.layers.retrieval.remove_accidental_hits import RemoveAccidentalHits as RemoveAccidentalHits
from keras_rs.src.layers.retrieval.retrieval import Retrieval as Retrieval
from keras_rs.src.layers.retrieval.sampling_probability_correction import (
SamplingProbabilityCorrection as SamplingProbabilityCorrection,
)
from keras_rs.src.layers.retrieval.sampling_probability_correction import SamplingProbabilityCorrection as SamplingProbabilityCorrection
18 changes: 6 additions & 12 deletions keras_rs/api/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,9 @@
since your modifications would be overwritten.
"""

from keras_rs.src.losses.pairwise_hinge_loss import (
PairwiseHingeLoss as PairwiseHingeLoss,
)
from keras_rs.src.losses.pairwise_logistic_loss import (
PairwiseLogisticLoss as PairwiseLogisticLoss,
)
from keras_rs.src.losses.pairwise_mean_squared_error import (
PairwiseMeanSquaredError as PairwiseMeanSquaredError,
)
from keras_rs.src.losses.pairwise_soft_zero_one_loss import (
PairwiseSoftZeroOneLoss as PairwiseSoftZeroOneLoss,
)

from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
from keras_rs.src.losses.pairwise_hinge_loss import PairwiseHingeLoss as PairwiseHingeLoss
from keras_rs.src.losses.pairwise_logistic_loss import PairwiseLogisticLoss as PairwiseLogisticLoss
from keras_rs.src.losses.pairwise_mean_squared_error import PairwiseMeanSquaredError as PairwiseMeanSquaredError
from keras_rs.src.losses.pairwise_soft_zero_one_loss import PairwiseSoftZeroOneLoss as PairwiseSoftZeroOneLoss
9 changes: 3 additions & 6 deletions keras_rs/api/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@
since your modifications would be overwritten.
"""


from keras_rs.src.metrics.dcg import DCG as DCG
from keras_rs.src.metrics.mean_average_precision import (
MeanAveragePrecision as MeanAveragePrecision,
)
from keras_rs.src.metrics.mean_reciprocal_rank import (
MeanReciprocalRank as MeanReciprocalRank,
)
from keras_rs.src.metrics.mean_average_precision import MeanAveragePrecision as MeanAveragePrecision
from keras_rs.src.metrics.mean_reciprocal_rank import MeanReciprocalRank as MeanReciprocalRank
from keras_rs.src.metrics.ndcg import NDCG as NDCG
from keras_rs.src.metrics.precision_at_k import PrecisionAtK as PrecisionAtK
from keras_rs.src.metrics.recall_at_k import RecallAtK as RecallAtK
195 changes: 195 additions & 0 deletions keras_rs/src/losses/list_mle_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from typing import Any

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
from keras_rs.src.api_export import keras_rs_export

@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
    the ground truth ranking. It works by:
    1. Sorting items by their relevance scores (labels)
    2. Computing the probability of observing this ranking given the
       predicted scores
    3. Maximizing this likelihood (minimizing negative log-likelihood)

    The loss is computed as the negative log-likelihood of the ground truth
    ranking given the predicted scores:

    ```
    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
    ```

    where s_i is the predicted score for item i in the sorted order.

    Items whose label is negative, or whose entry in the optional mask is
    false, are excluded from both the ranking and the normalizers. A list
    with no valid items contributes a loss of 0.

    Args:
        temperature: Temperature parameter for scaling logits. Higher values
            make the probability distribution more uniform. Defaults to 1.0.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Raises:
        ValueError: If `temperature` is not strictly positive.

    Examples:
    ```python
    # Basic usage
    loss_fn = ListMLELoss()

    # With temperature scaling
    loss_fn = ListMLELoss(temperature=0.5)

    # Example with synthetic data
    y_true = [[3, 2, 1, 0]]  # Relevance scores
    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
    loss = loss_fn(y_true, y_pred)
    ```
    """

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                "`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}."
            )

        self.temperature = temperature
        # Lower bound applied to the suffix sums before taking their log.
        # It is used as a floor (`maximum`), not as an additive term, so it
        # only guards against log(0) and does not bias well-defined values.
        # 1e-37 is still a normal float32 number, which keeps 1 / sum finite
        # during backpropagation.
        self._epsilon = 1e-37

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute the unreduced ListMLE loss.

        Args:
            labels: Ground truth relevance scores of
                shape [batch_size, list_size]. Entries with a negative
                label are treated as invalid and ignored.
            logits: Predicted scores of shape [batch_size, list_size].
            mask: Optional mask of shape [batch_size, list_size]. Entries
                that are false after casting to bool are ignored.

        Returns:
            Tuple of (losses, weights) where losses has shape [batch_size, 1]
            and weights is a tensor of ones with the same shape.
        """
        # An item is valid when its label is non-negative and, if a mask is
        # supplied, the mask keeps it.
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
        if mask is not None:
            valid_mask = ops.logical_and(
                valid_mask, ops.cast(mask, dtype="bool")
            )

        num_valid_items = ops.sum(
            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
        )
        batch_has_valid_items = ops.greater(num_valid_items, 0.0)

        # Push invalid items to the end of the descending-label ordering and
        # give them a very negative score.
        labels_for_sorting = ops.where(
            valid_mask, labels, ops.full_like(labels, -1e9)
        )
        logits_masked = ops.where(
            valid_mask, logits, ops.full_like(logits, -1e9)
        )

        # Reorder scores (and the validity mask) by decreasing label.
        sorted_indices = ops.argsort(-labels_for_sorting, axis=-1)
        sorted_logits = ops.take_along_axis(
            logits_masked, sorted_indices, axis=-1
        )
        sorted_valid_mask = ops.take_along_axis(
            valid_mask, sorted_indices, axis=-1
        )

        sorted_logits = ops.divide(
            sorted_logits,
            ops.cast(self.temperature, dtype=sorted_logits.dtype),
        )

        # Subtract the per-list maximum over valid items so that `exp` cannot
        # overflow. Lists without any valid item are shifted by 0 instead.
        valid_logits_for_max = ops.where(
            sorted_valid_mask,
            sorted_logits,
            ops.full_like(sorted_logits, -1e9),
        )
        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
        raw_max = ops.where(
            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
        )
        sorted_logits = sorted_logits - raw_max

        exp_logits = ops.exp(sorted_logits)
        exp_logits = ops.where(
            sorted_valid_mask, exp_logits, ops.zeros_like(exp_logits)
        )

        # Suffix sums sum(exp(s_j) for j >= i), computed as a cumulative sum
        # over the reversed list.
        reversed_exp = ops.flip(exp_logits, axis=1)
        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)

        # Clamp from below instead of adding a constant: an additive term
        # distorts the normalizer of low-scoring tail items (whose own
        # log-probability should be exactly 0 at the last position), while a
        # floor leaves every representable suffix sum untouched.
        log_normalizers = ops.log(
            ops.maximum(
                cumsum_from_right,
                ops.cast(self._epsilon, dtype=cumsum_from_right.dtype),
            )
        )
        log_probs = sorted_logits - log_normalizers
        log_probs = ops.where(
            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
        )

        negative_log_likelihood = -ops.sum(log_probs, axis=1, keepdims=True)
        negative_log_likelihood = ops.where(
            batch_has_valid_items,
            negative_log_likelihood,
            ops.zeros_like(negative_log_likelihood),
        )

        weights = ops.ones_like(negative_log_likelihood)

        return negative_log_likelihood, weights

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the ListMLE loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
                for batched inputs. If an item has a label of -1, it is ignored
                in loss computation. If it is a dictionary, it should have two
                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
                elements in loss computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape [batch_size].

        Raises:
            ValueError: If `y_true` is a dict without a `"labels"` key.
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )

            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        y_true = ops.convert_to_tensor(y_true)
        y_pred = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
            y_true, y_pred, mask
        )

        losses, weights = self.compute_unreduced_loss(
            labels=y_true, logits=y_pred, mask=mask
        )
        losses = ops.multiply(losses, weights)
        # Drop the trailing singleton axis kept by `keepdims=True`.
        losses = ops.squeeze(losses, axis=-1)
        return losses

    def get_config(self) -> dict[str, Any]:
        """Return the base loss config extended with `temperature`."""
        config: dict[str, Any] = super().get_config()
        config.update({"temperature": self.temperature})
        return config
88 changes: 88 additions & 0 deletions keras_rs/src/losses/list_mle_loss_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.losses import deserialize
from keras.losses import serialize

from keras_rs.src import testing
from keras_rs.src.losses.list_mle_loss import ListMLELoss


class ListMLELossTest(testing.TestCase, parameterized.TestCase):
    """Unit tests for `ListMLELoss`."""

    def setUp(self):
        # Run the fixture setup of the base test classes before adding the
        # fixtures specific to this test case.
        super().setUp()

        self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
        self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])

        self.batched_scores = ops.array(
            [[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]]
        )
        self.batched_labels = ops.array(
            [[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]]
        )
        # Per-list losses expected for the batched fixtures at the default
        # temperature; the first entry also matches the unbatched fixtures.
        self.expected_output = ops.array([6.865693, 3.088192])

    def test_unbatched_input(self):
        loss = ListMLELoss(reduction="none")
        output = loss(
            y_true=self.unbatched_labels, y_pred=self.unbatched_scores
        )
        self.assertEqual(output.shape, (1,))
        self.assertGreater(ops.convert_to_numpy(output[0]), 0)
        self.assertAllClose(output, [self.expected_output[0]], atol=1e-5)

    def test_batched_input(self):
        loss = ListMLELoss(reduction="none")
        output = loss(y_true=self.batched_labels, y_pred=self.batched_scores)
        self.assertEqual(output.shape, (2,))
        self.assertGreater(ops.convert_to_numpy(output[0]), 0)
        self.assertGreater(ops.convert_to_numpy(output[1]), 0)
        self.assertAllClose(output, self.expected_output, atol=1e-5)

    def test_temperature(self):
        loss_temp = ListMLELoss(temperature=0.5, reduction="none")
        output_temp = loss_temp(
            y_true=self.batched_labels, y_pred=self.batched_scores
        )

        self.assertAllClose(
            output_temp,
            [10.969891, 2.1283305],
            atol=1e-5,
        )

    def test_invalid_input_rank(self):
        # Inputs with three axes are expected to be rejected.
        rank_3_input = ops.ones((2, 3, 4))

        loss = ListMLELoss()
        with self.assertRaises(ValueError):
            loss(y_true=rank_3_input, y_pred=rank_3_input)

    def test_loss_reduction(self):
        loss = ListMLELoss(reduction="sum_over_batch_size")
        output = loss(y_true=self.batched_labels, y_pred=self.batched_scores)

        self.assertAlmostEqual(
            ops.convert_to_numpy(output), 4.9769425, places=5
        )

    def test_scalar_sample_weight(self):
        sample_weight = ops.array(5.0)
        loss = ListMLELoss(reduction="none")

        output = loss(
            y_true=self.batched_labels,
            y_pred=self.batched_scores,
            sample_weight=sample_weight,
        )

        self.assertAllClose(
            output, self.expected_output * sample_weight, atol=1e-5
        )

    def test_model_fit(self):
        # Smoke test: the loss can be compiled into a model and trained.
        inputs = keras.Input(shape=(20,), dtype="float32")
        outputs = keras.layers.Dense(5)(inputs)
        model = keras.Model(inputs=inputs, outputs=outputs)

        model.compile(loss=ListMLELoss(), optimizer="adam")
        model.fit(
            x=keras.random.normal((2, 20)),
            y=keras.random.randint((2, 5), minval=0, maxval=2),
        )

    def test_serialization(self):
        loss = ListMLELoss(temperature=0.8)
        restored = deserialize(serialize(loss))
        self.assertDictEqual(loss.get_config(), restored.get_config())