diff --git a/api_gen.py b/api_gen.py index da0925c..c8c87ba 100755 --- a/api_gen.py +++ b/api_gen.py @@ -8,7 +8,7 @@ import os import shutil - +import pre_commit import namex PACKAGE = "keras_rs" diff --git a/keras_rs/api/__init__.py b/keras_rs/api/__init__.py index 09ff157..121eab4 100644 --- a/keras_rs/api/__init__.py +++ b/keras_rs/api/__init__.py @@ -4,8 +4,9 @@ since your modifications would be overwritten. """ + from keras_rs import layers as layers from keras_rs import losses as losses from keras_rs import metrics as metrics -from keras_rs.src.version import __version__ as __version__ from keras_rs.src.version import version as version +from keras_rs.src.version import __version__ as __version__ diff --git a/keras_rs/api/layers/__init__.py b/keras_rs/api/layers/__init__.py index 8d740e8..2fb894f 100644 --- a/keras_rs/api/layers/__init__.py +++ b/keras_rs/api/layers/__init__.py @@ -4,34 +4,15 @@ since your modifications would be overwritten. """ -from keras_rs.src.layers.embedding.distributed_embedding import ( - DistributedEmbedding as DistributedEmbedding, -) -from keras_rs.src.layers.embedding.distributed_embedding_config import ( - FeatureConfig as FeatureConfig, -) -from keras_rs.src.layers.embedding.distributed_embedding_config import ( - TableConfig as TableConfig, -) -from keras_rs.src.layers.embedding.embed_reduce import ( - EmbedReduce as EmbedReduce, -) -from keras_rs.src.layers.feature_interaction.dot_interaction import ( - DotInteraction as DotInteraction, -) -from keras_rs.src.layers.feature_interaction.feature_cross import ( - FeatureCross as FeatureCross, -) -from keras_rs.src.layers.retrieval.brute_force_retrieval import ( - BruteForceRetrieval as BruteForceRetrieval, -) -from keras_rs.src.layers.retrieval.hard_negative_mining import ( - HardNegativeMining as HardNegativeMining, -) -from keras_rs.src.layers.retrieval.remove_accidental_hits import ( - RemoveAccidentalHits as RemoveAccidentalHits, -) + +from keras_rs.src.layers.embedding.distributed_embedding import DistributedEmbedding as DistributedEmbedding +from keras_rs.src.layers.embedding.distributed_embedding_config import FeatureConfig as FeatureConfig +from keras_rs.src.layers.embedding.distributed_embedding_config import TableConfig as TableConfig +from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce as EmbedReduce +from keras_rs.src.layers.feature_interaction.dot_interaction import DotInteraction as DotInteraction +from keras_rs.src.layers.feature_interaction.feature_cross import FeatureCross as FeatureCross +from keras_rs.src.layers.retrieval.brute_force_retrieval import BruteForceRetrieval as BruteForceRetrieval +from keras_rs.src.layers.retrieval.hard_negative_mining import HardNegativeMining as HardNegativeMining +from keras_rs.src.layers.retrieval.remove_accidental_hits import RemoveAccidentalHits as RemoveAccidentalHits from keras_rs.src.layers.retrieval.retrieval import Retrieval as Retrieval -from keras_rs.src.layers.retrieval.sampling_probability_correction import ( - SamplingProbabilityCorrection as SamplingProbabilityCorrection, -) +from keras_rs.src.layers.retrieval.sampling_probability_correction import SamplingProbabilityCorrection as SamplingProbabilityCorrection diff --git a/keras_rs/api/losses/__init__.py b/keras_rs/api/losses/__init__.py index 152b449..92f2a6d 100644 --- a/keras_rs/api/losses/__init__.py +++ b/keras_rs/api/losses/__init__.py @@ -4,15 +4,9 @@ since your modifications would be overwritten. """ -from keras_rs.src.losses.pairwise_hinge_loss import ( - PairwiseHingeLoss as PairwiseHingeLoss, -) -from keras_rs.src.losses.pairwise_logistic_loss import ( - PairwiseLogisticLoss as PairwiseLogisticLoss, -) -from keras_rs.src.losses.pairwise_mean_squared_error import ( - PairwiseMeanSquaredError as PairwiseMeanSquaredError, -) -from keras_rs.src.losses.pairwise_soft_zero_one_loss import ( - PairwiseSoftZeroOneLoss as PairwiseSoftZeroOneLoss, -) + +from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss +from keras_rs.src.losses.pairwise_hinge_loss import PairwiseHingeLoss as PairwiseHingeLoss +from keras_rs.src.losses.pairwise_logistic_loss import PairwiseLogisticLoss as PairwiseLogisticLoss +from keras_rs.src.losses.pairwise_mean_squared_error import PairwiseMeanSquaredError as PairwiseMeanSquaredError +from keras_rs.src.losses.pairwise_soft_zero_one_loss import PairwiseSoftZeroOneLoss as PairwiseSoftZeroOneLoss diff --git a/keras_rs/api/metrics/__init__.py b/keras_rs/api/metrics/__init__.py index c66c78a..176c251 100644 --- a/keras_rs/api/metrics/__init__.py +++ b/keras_rs/api/metrics/__init__.py @@ -4,13 +4,10 @@ since your modifications would be overwritten. """ + from keras_rs.src.metrics.dcg import DCG as DCG -from keras_rs.src.metrics.mean_average_precision import ( - MeanAveragePrecision as MeanAveragePrecision, -) -from keras_rs.src.metrics.mean_reciprocal_rank import ( - MeanReciprocalRank as MeanReciprocalRank, -) +from keras_rs.src.metrics.mean_average_precision import MeanAveragePrecision as MeanAveragePrecision +from keras_rs.src.metrics.mean_reciprocal_rank import MeanReciprocalRank as MeanReciprocalRank from keras_rs.src.metrics.ndcg import NDCG as NDCG from keras_rs.src.metrics.precision_at_k import PrecisionAtK as PrecisionAtK from keras_rs.src.metrics.recall_at_k import RecallAtK as RecallAtK diff --git a/keras_rs/src/losses/list_mle_loss.py b/keras_rs/src/losses/list_mle_loss.py new file mode 100644 index 0000000..2927a96 --- /dev/null +++ b/keras_rs/src/losses/list_mle_loss.py @@ -0,0 +1,195 @@ +from typing import Any + +import keras +from keras import ops + +from keras_rs.src import types +from keras_rs.src.metrics.utils import standardize_call_inputs_ranks +from keras_rs.src.api_export import keras_rs_export + +@keras_rs_export("keras_rs.losses.ListMLELoss") +class ListMLELoss(keras.losses.Loss): + """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking. + + ListMLE loss is a listwise ranking loss that maximizes the likelihood of + the ground truth ranking. It works by: + 1. Sorting items by their relevance scores (labels) + 2. Computing the probability of observing this ranking given the + predicted scores + 3. Maximizing this likelihood (minimizing negative log-likelihood) + + The loss is computed as the negative log-likelihood of the ground truth + ranking given the predicted scores: + + ``` + loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i))) + ``` + + where s_i is the predicted score for item i in the sorted order. + + Args: + temperature: Temperature parameter for scaling logits. Higher values + make the probability distribution more uniform. Defaults to 1.0. + reduction: Type of reduction to apply to the loss. In almost all cases + this should be `"sum_over_batch_size"`. Supported options are + `"sum"`, `"sum_over_batch_size"`, `"mean"`, + `"mean_with_sample_weight"` or `None`. Defaults to + `"sum_over_batch_size"`. + name: Optional name for the loss instance. + dtype: The dtype of the loss's computations. Defaults to `None`. + + Examples: + ```python + # Basic usage + loss_fn = ListMLELoss() + + # With temperature scaling + loss_fn = ListMLELoss(temperature=0.5) + + # Example with synthetic data + y_true = [[3, 2, 1, 0]] # Relevance scores + y_pred = [[0.8, 0.6, 0.4, 0.2]] # Predicted scores + loss = loss_fn(y_true, y_pred) + ``` + """ + + def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None: + super().__init__(**kwargs) + + if temperature <= 0.0: + raise ValueError( + f"`temperature` should be a positive float. Received: " + f"`temperature` = {temperature}." + ) + + self.temperature = temperature + self._epsilon = 1e-10 + + + def compute_unreduced_loss( + self, + labels: types.Tensor, + logits: types.Tensor, + mask: types.Tensor | None = None, + ) -> tuple[types.Tensor, types.Tensor]: + """Compute the unreduced ListMLE loss. + + Args: + labels: Ground truth relevance scores of + shape [batch_size,list_size]. + logits: Predicted scores of shape [batch_size, list_size]. + mask: Optional mask of shape [batch_size, list_size]. + + Returns: + Tuple of (losses, weights) where losses has shape [batch_size, 1] + and weights has the same shape. + """ + + valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype)) + + if mask is not None: + valid_mask = ops.logical_and(valid_mask, ops.cast(mask, dtype="bool")) + + num_valid_items = ops.sum(ops.cast(valid_mask, dtype=labels.dtype), + axis=1, keepdims=True) + + batch_has_valid_items = ops.greater(num_valid_items, 0.0) + + labels_for_sorting = ops.where(valid_mask, labels, ops.full_like(labels, -1e9)) + logits_masked = ops.where(valid_mask, logits, ops.full_like(logits, -1e9)) + + sorted_indices = ops.argsort(-labels_for_sorting, axis=-1) + + sorted_logits = ops.take_along_axis(logits_masked, sorted_indices, axis=-1) + sorted_valid_mask = ops.take_along_axis(valid_mask, sorted_indices, axis=-1) + + sorted_logits = ops.divide( + sorted_logits, + ops.cast(self.temperature, dtype=sorted_logits.dtype) + ) + + valid_logits_for_max = ops.where(sorted_valid_mask, sorted_logits, + ops.full_like(sorted_logits, -1e9)) + raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True) + + raw_max = ops.where(batch_has_valid_items, raw_max, ops.zeros_like(raw_max)) + sorted_logits = sorted_logits - raw_max + + + exp_logits = ops.exp(sorted_logits) + + exp_logits = ops.where(sorted_valid_mask, exp_logits, ops.zeros_like(exp_logits)) + + reversed_exp = ops.flip(exp_logits, axis=1) + reversed_cumsum = ops.cumsum(reversed_exp, axis=1) + cumsum_from_right = ops.flip(reversed_cumsum, axis=1) + + + log_normalizers = ops.log(cumsum_from_right + self._epsilon) + log_probs = sorted_logits - log_normalizers + + log_probs = ops.where(sorted_valid_mask, log_probs, ops.zeros_like(log_probs)) + + + negative_log_likelihood = -ops.sum(log_probs, axis=1, keepdims=True) + + negative_log_likelihood = ops.where(batch_has_valid_items, negative_log_likelihood, + ops.zeros_like(negative_log_likelihood)) + + weights = ops.ones_like(negative_log_likelihood) + + return negative_log_likelihood, weights + + def call( + self, + y_true: types.Tensor, + y_pred: types.Tensor, + ) -> types.Tensor: + """Compute the ListMLE loss. + + Args: + y_true: tensor or dict. Ground truth values. If tensor, of shape + `(list_size)` for unbatched inputs or `(batch_size, list_size)` + for batched inputs. If an item has a label of -1, it is ignored + in loss computation. If it is a dictionary, it should have two + keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore + elements in loss computation. + y_pred: tensor. The predicted values, of shape `(list_size)` for + unbatched inputs or `(batch_size, list_size)` for batched + inputs. Should be of the same shape as `y_true`. + + Returns: + The loss tensor of shape [batch_size]. + """ + mask = None + if isinstance(y_true, dict): + if "labels" not in y_true: + raise ValueError( + '`"labels"` should be present in `y_true`. Received: ' + f"`y_true` = {y_true}" + ) + + mask = y_true.get("mask", None) + y_true = y_true["labels"] + + y_true = ops.convert_to_tensor(y_true) + y_pred = ops.convert_to_tensor(y_pred) + if mask is not None: + mask = ops.convert_to_tensor(mask) + + y_true, y_pred, mask, _ = standardize_call_inputs_ranks( + y_true, y_pred, mask + ) + + losses, weights = self.compute_unreduced_loss( + labels=y_true, logits=y_pred, mask=mask + ) + losses = ops.multiply(losses, weights) + losses = ops.squeeze(losses, axis=-1) + return losses + + + def get_config(self) -> dict[str, Any]: + config: dict[str, Any] = super().get_config() + config.update({"temperature": self.temperature}) + return config diff --git a/keras_rs/src/losses/list_mle_loss_test.py b/keras_rs/src/losses/list_mle_loss_test.py new file mode 100644 index 0000000..c042bcc --- /dev/null +++ b/keras_rs/src/losses/list_mle_loss_test.py @@ -0,0 +1,88 @@ +import keras +from absl.testing import parameterized +from keras import ops +from keras.losses import deserialize +from keras.losses import serialize + +from keras_rs.src import testing +from keras_rs.src.losses.list_mle_loss import ListMLELoss + + +class ListMLELossTest(testing.TestCase, parameterized.TestCase): + def setUp(self): + self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8]) + self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0]) + + self.batched_scores = ops.array( + [[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]] + ) + self.batched_labels = ops.array( + [[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]] + ) + self.expected_output = ops.array([6.865693, 3.088192]) + + def test_unbatched_input(self): + loss = ListMLELoss(reduction="none") + output = loss( + y_true=self.unbatched_labels, y_pred=self.unbatched_scores + ) + self.assertEqual(output.shape, (1,)) + self.assertTrue(ops.convert_to_numpy(output[0]) > 0) + self.assertAllClose(output, [self.expected_output[0]], atol=1e-5) + + def test_batched_input(self): + loss = ListMLELoss(reduction="none") + output = loss(y_true=self.batched_labels, y_pred=self.batched_scores) + self.assertEqual(output.shape, (2,)) + self.assertTrue(ops.convert_to_numpy(output[0]) > 0) + self.assertTrue(ops.convert_to_numpy(output[1]) > 0) + self.assertAllClose(output, self.expected_output, atol=1e-5) + + def test_temperature(self): + + loss_temp = ListMLELoss(temperature=0.5, reduction="none") + output_temp = loss_temp(y_true=self.batched_labels, y_pred=self.batched_scores) + + self.assertAllClose(output_temp,[10.969891,2.1283305],atol=1e-5, + ) + + def test_invalid_input_rank(self): + rank_1_input = ops.ones((2, 3, 4)) + + loss = ListMLELoss() + with self.assertRaises(ValueError): + loss(y_true=rank_1_input, y_pred=rank_1_input) + + def test_loss_reduction(self): + loss = ListMLELoss(reduction="sum_over_batch_size") + output = loss(y_true=self.batched_labels, y_pred=self.batched_scores) + + self.assertAlmostEqual(ops.convert_to_numpy(output), 4.9769425, places=5) + + def test_scalar_sample_weight(self): + sample_weight = ops.array(5.0) + loss = ListMLELoss(reduction="none") + + output = loss( + y_true=self.batched_labels, + y_pred=self.batched_scores, + sample_weight=sample_weight, + ) + + self.assertAllClose(output, self.expected_output * sample_weight, atol=1e-5) + + def test_model_fit(self): + inputs = keras.Input(shape=(20,), dtype="float32") + outputs = keras.layers.Dense(5)(inputs) + model = keras.Model(inputs=inputs, outputs=outputs) + + model.compile(loss=ListMLELoss(), optimizer="adam") + model.fit( + x=keras.random.normal((2, 20)), + y=keras.random.randint((2, 5), minval=0, maxval=2), + ) + + def test_serialization(self): + loss = ListMLELoss(temperature=0.8) + restored = deserialize(serialize(loss)) + self.assertDictEqual(loss.get_config(), restored.get_config())