- 
          
 - 
                Notifications
    
You must be signed in to change notification settings  - Fork 657
 
Added the ndcg metric [WIP] #2632
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Draft
      
      
            kamalojasv181
  wants to merge
  38
  commits into
  pytorch:master
  
    
      
        
          
  
    
      Choose a base branch
      
     
    
      
        
      
      
        
          
          
        
        
          
            
              
              
              
  
           
        
        
          
            
              
              
           
        
       
     
  
        
          
            
          
            
          
        
       
    
      
from
kamalojasv181:ndcg
  
      
      
   
  
    
  
  
  
 
  
      
    base: master
Could not load branches
            
              
  
    Branch not found: {{ refName }}
  
            
                
      Loading
              
            Could not load tags
            
            
              Nothing to show
            
              
  
            
                
      Loading
              
            Are you sure you want to change the base?
            Some commits from the old base branch may be removed from the timeline,
            and old review comments may become outdated.
          
          
      
        
          +394
        
        
          −0
        
        
          
        
      
    
  
  
     Draft
                    Changes from 13 commits
      Commits
    
    
            Show all changes
          
          
            38 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      cd6ec7f
              
                Added the ndcg metric [WIP]
              
              
                kamalojasv181 6421879
              
                Merge branch 'master' into ndcg
              
              
                kamalojasv181 4535af1
              
                added GPU support, corrected mypy errors, and minor fixes
              
              
                kamalojasv181 eb73c99
              
                Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
              
              
                kamalojasv181 70d06e5
              
                Incorporated the suggested changes
              
              
                kamalojasv181 6a86f5f
              
                Fixed mypy error
              
              
                kamalojasv181 7b7ed6f
              
                Fixed bugs in NDCG and added tests for output and reset
              
              
                kamalojasv181 2c87ee1
              
                Fixed mypy error
              
              
                kamalojasv181 f4c628a
              
                Added the exponential form on https://en.wikipedia.org/wiki/Discounte…
              
              
                kamalojasv181 e72b59e
              
                Corrected true, pred order and corresponding tests
              
              
                kamalojasv181 ef63d85
              
                Added ties case, exponential tests, log_base tests, corresponding tes…
              
              
                kamalojasv181 115501b
              
                Added GPU check on top
              
              
                kamalojasv181 189b579
              
                Put tensors on GPU inside the function to pervent error
              
              
                kamalojasv181 c509456
              
                Improved tests and minor bugfixes
              
              
                kamalojasv181 84900f0
              
                Removed device hyperparam from _ndcg_smaple_scores
              
              
                kamalojasv181 9bfc06e
              
                Skipped GPU tests for CPU only systems
              
              
                kamalojasv181 477e096
              
                Changed Error message
              
              
                kamalojasv181 44329d7
              
                Merge branch 'master' of https://github.com/pytorch/ignite into ndcg
              
              
                kamalojasv181 5ba7fb7
              
                Merge branch 'pytorch:master' into ndcg
              
              
                kamalojasv181 ac800ff
              
                Made tests randomised from deterministic and introduced ignore_ties_f…
              
              
                kamalojasv181 691e89a
              
                Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
              
              
                kamalojasv181 962bcef
              
                Merge branch 'master' of https://github.com/pytorch/ignite into ndcg
              
              
                kamalojasv181 79979cc
              
                Merge branch 'pytorch:master' into ndcg
              
              
                kamalojasv181 0c1d6fd
              
                Changed test name to test_output_cuda from test_output_gpu
              
              
                kamalojasv181 85cdcaf
              
                Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
              
              
                kamalojasv181 fdf7877
              
                Merge branch 'master' into ndcg
              
              
                kamalojasv181 2931d20
              
                Changed variable names to replacement and ignore_ties and removed red…
              
              
                kamalojasv181 c308e41
              
                Changed variable names to replacement and ignore_ties and removed red…
              
              
                kamalojasv181 95ede6c
              
                Merge branch 'master' of https://github.com/pytorch/ignite into ndcg
              
              
                kamalojasv181 3a4d2af
              
                Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
              
              
                kamalojasv181 6e66273
              
                Removed redundant test cases and removed the redundant if statement
              
              
                kamalojasv181 eb75afa
              
                Added distributed tests, added multiple test cases corresponding to o…
              
              
                kamalojasv181 dcf276d
              
                Made the tests wsork on in ddp configuration
              
              
                kamalojasv181 cb273e7
              
                Merge branch 'master' of https://github.com/pytorch/ignite into ndcg
              
              
                kamalojasv181 b0f449b
              
                Merge branch 'pytorch:master' into ndcg
              
              
                kamalojasv181 388db23
              
                Merge branch 'ndcg' of https://github.com/kamalojasv181/ignite into ndcg
              
              
                kamalojasv181 b3b6b28
              
                Merge branch 'pytorch:master' into ndcg
              
              
                kamalojasv181 6dcf3b2
              
                Returning tuple of two tensors instead of tuple of list of tensors
              
              
                kamalojasv181 File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from ignite.metrics.recsys.ndcg import NDCG | ||
| 
     | 
||
| __all__ = [ | ||
| "NDCG", | ||
| ] | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| from typing import Callable, Optional, Sequence, Union | ||
| 
     | 
||
| import torch | ||
| 
     | 
||
| from ignite.exceptions import NotComputableError | ||
| from ignite.metrics.metric import Metric | ||
| 
     | 
||
| __all__ = ["NDCG"] | ||
| 
     | 
||
| 
     | 
||
| def _tie_averaged_dcg( | ||
| y_pred: torch.Tensor, | ||
| y_true: torch.Tensor, | ||
| discount_cumsum: torch.Tensor, | ||
| device: Union[str, torch.device] = torch.device("cpu"), | ||
| ) -> torch.Tensor: | ||
| 
     | 
||
| _, inv, counts = torch.unique(-y_pred, return_inverse=True, return_counts=True) | ||
| ranked = torch.zeros(counts.shape[0]).to(device) | ||
| ranked.index_put_([inv], y_true, accumulate=True) | ||
| ranked /= counts | ||
| groups = torch.cumsum(counts, dim=-1) - 1 | ||
| discount_sums = torch.empty(counts.shape[0]).to(device) | ||
| discount_sums[0] = discount_cumsum[groups[0]] | ||
| discount_sums[1:] = torch.diff(discount_cumsum[groups]) | ||
| 
     | 
||
| return torch.sum(torch.mul(ranked, discount_sums)) | ||
| 
     | 
||
| 
     | 
||
| def _dcg_sample_scores( | ||
| y_pred: torch.Tensor, | ||
| y_true: torch.Tensor, | ||
| k: Optional[int] = None, | ||
| log_base: Union[int, float] = 2, | ||
| ignore_ties: bool = False, | ||
| device: Union[str, torch.device] = torch.device("cpu"), | ||
| ) -> torch.Tensor: | ||
| 
     | 
||
| discount = torch.log(torch.tensor(log_base)) / torch.log(torch.arange(y_true.shape[1]) + 2) | ||
| discount = discount.to(device) | ||
| 
     | 
||
| if k is not None: | ||
| discount[k:] = 0.0 | ||
| 
     | 
||
| if ignore_ties: | ||
| ranking = torch.argsort(y_pred, descending=True) | ||
| ranked = y_true[torch.arange(ranking.shape[0]).reshape(-1, 1), ranking].to(device) | ||
| discounted_gains = torch.mm(ranked, discount.reshape(-1, 1)) | ||
| 
     | 
||
| else: | ||
| discount_cumsum = torch.cumsum(discount, dim=-1) | ||
| discounted_gains = torch.tensor( | ||
| [_tie_averaged_dcg(y_p, y_t, discount_cumsum, device) for y_p, y_t in zip(y_pred, y_true)], device=device | ||
| ) | ||
| 
     | 
||
| return discounted_gains | ||
| 
     | 
||
| 
     | 
||
| def _ndcg_sample_scores( | ||
| y_pred: torch.Tensor, | ||
| y_true: torch.Tensor, | ||
| k: Optional[int] = None, | ||
| log_base: Union[int, float] = 2, | ||
| ignore_ties: bool = False, | ||
| device: Union[str, torch.device] = torch.device("cpu"), | ||
| ) -> torch.Tensor: | ||
| 
     | 
||
| gain = _dcg_sample_scores(y_pred, y_true, k=k, log_base=log_base, ignore_ties=ignore_ties, device=device) | ||
| if not ignore_ties: | ||
| gain = gain.unsqueeze(dim=-1) | ||
| normalizing_gain = _dcg_sample_scores(y_true, y_true, k=k, log_base=log_base, ignore_ties=True, device=device) | ||
| all_relevant = normalizing_gain != 0 | ||
| normalized_gain = gain[all_relevant] / normalizing_gain[all_relevant] | ||
| return normalized_gain | ||
| 
     | 
||
| 
     | 
||
| class NDCG(Metric): | ||
| def __init__( | ||
| self, | ||
| output_transform: Callable = lambda x: x, | ||
| device: Union[str, torch.device] = torch.device("cpu"), | ||
| k: Optional[int] = None, | ||
| log_base: Union[int, float] = 2, | ||
| exponential: bool = False, | ||
| ignore_ties: bool = False, | ||
| ): | ||
| 
     | 
||
| assert log_base != 1 or log_base <= 0, f"Illegal value {log_base} for log_base" | ||
| self.log_base = log_base | ||
                
      
                  vfdev-5 marked this conversation as resolved.
               
          
            Show resolved
            Hide resolved
         | 
||
| self.k = k | ||
| self.exponential = exponential | ||
| super(NDCG, self).__init__(output_transform=output_transform, device=device) | ||
| self.ignore_ties = ignore_ties | ||
| 
     | 
||
| def reset(self) -> None: | ||
| 
     | 
||
| self.num_examples = 0 | ||
| self.ndcg = torch.tensor(0.0, device=self._device) | ||
| 
     | 
||
| def update(self, output: Sequence[torch.Tensor]) -> None: | ||
| 
     | 
||
| y_pred, y_true = output[0].detach(), output[1].detach() | ||
| 
     | 
||
| if self.exponential: | ||
| y_true = 2 ** y_true - 1 | ||
| 
     | 
||
| gain = _ndcg_sample_scores(y_pred, y_true, k=self.k, log_base=self.log_base, device=self._device) | ||
                
      
                  vfdev-5 marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| self.ndcg += torch.sum(gain) | ||
| self.num_examples += y_pred.shape[0] | ||
| 
     | 
||
| def compute(self) -> float: | ||
| if self.num_examples == 0: | ||
| raise NotComputableError("NGCD must have at least one example before it can be computed.") | ||
| 
     | 
||
| return (self.ndcg / self.num_examples).item() | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,187 @@ | ||
| import numpy as np | ||
| import pytest | ||
| import torch | ||
| from sklearn.metrics import ndcg_score | ||
| from sklearn.metrics._ranking import _dcg_sample_scores | ||
| 
     | 
||
| from ignite.exceptions import NotComputableError | ||
| from ignite.metrics.recsys.ndcg import NDCG | ||
| 
     | 
||
| 
     | 
||
| @pytest.mark.parametrize( | ||
| "y_pred, y_true", | ||
| [ | ||
| (torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])), | ||
| (torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])), | ||
| ( | ||
| torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [3.7, 4.8, 3.9, 4.3, 4.9], [3.7, 4.8, 3.9, 4.3, 4.9]]), | ||
| torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.2, 4.5, 8.9, 5.6, 7.2], [2.9, 5.6, 3.8, 7.9, 6.2]]), | ||
| ), | ||
| (torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7]]), torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])), | ||
| ( | ||
| torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7], [3.7, 3.7, 3.7, 3.7, 3.9]]), | ||
| torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]), | ||
| ), | ||
| ], | ||
                
      
                  vfdev-5 marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| ) | ||
| @pytest.mark.parametrize("k", [None, 2, 3]) | ||
                
      
                  vfdev-5 marked this conversation as resolved.
               
          
            Show resolved
            Hide resolved
         | 
||
| def test_output_cpu(y_pred, y_true, k): | ||
| 
     | 
||
| device = "cpu" | ||
| 
     | 
||
| ndcg = NDCG(k=k, device=device) | ||
| ndcg.update([y_pred, y_true]) | ||
| result_ignite = ndcg.compute() | ||
| result_sklearn = ndcg_score(y_true.numpy(), y_pred.numpy(), k=k) | ||
| 
     | 
||
| np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-7) | ||
| 
     | 
||
| 
     | 
||
| @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") | ||
| @pytest.mark.parametrize( | ||
| "y_pred, y_true", | ||
| [ | ||
| (torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])), | ||
| (torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])), | ||
| ( | ||
| torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [3.7, 4.8, 3.9, 4.3, 4.9], [3.7, 4.8, 3.9, 4.3, 4.9]]), | ||
| torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.2, 4.5, 8.9, 5.6, 7.2], [2.9, 5.6, 3.8, 7.9, 6.2]]), | ||
| ), | ||
| (torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7]]), torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])), | ||
| ( | ||
| torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7], [3.7, 3.7, 3.7, 3.7, 3.9]]), | ||
| torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]), | ||
| ), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("k", [None, 2, 3]) | ||
| def test_output_gpu(y_pred, y_true, k): | ||
| 
     | 
||
| device = "cuda" | ||
| y_pred = y_pred.to(device) | ||
| y_true = y_true.to(device) | ||
| ndcg = NDCG(k=k, device=device) | ||
| ndcg.update([y_pred, y_true]) | ||
| result_ignite = ndcg.compute() | ||
| result_sklearn = ndcg_score(y_true.cpu().numpy(), y_pred.cpu().numpy(), k=k) | ||
| 
     | 
||
| np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-7) | ||
| 
     | 
||
| 
     | 
||
| def test_reset(): | ||
| 
     | 
||
| y_true = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) | ||
| y_pred = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]) | ||
| ndcg = NDCG() | ||
| ndcg.update([y_pred, y_true]) | ||
| ndcg.reset() | ||
| 
     | 
||
| with pytest.raises(NotComputableError, match=r"NGCD must have at least one example before it can be computed."): | ||
| ndcg.compute() | ||
| 
     | 
||
| 
     | 
||
| @pytest.mark.parametrize( | ||
| "y_pred, y_true", | ||
| [ | ||
| (torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])), | ||
| (torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])), | ||
| ( | ||
| torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [3.7, 4.8, 3.9, 4.3, 4.9], [3.7, 4.8, 3.9, 4.3, 4.9]]), | ||
| torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.2, 4.5, 8.9, 5.6, 7.2], [2.9, 5.6, 3.8, 7.9, 6.2]]), | ||
| ), | ||
| (torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7]]), torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])), | ||
| ( | ||
| torch.tensor([[3.7, 3.7, 3.7, 3.7, 3.7], [3.7, 3.7, 3.7, 3.7, 3.9]]), | ||
| torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]]), | ||
| ), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("k", [None, 2, 3]) | ||
| def test_exponential(y_pred, y_true, k): | ||
                
      
                  vfdev-5 marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| 
     | 
||
| device = "cpu" | ||
| 
     | 
||
| ndcg = NDCG(k=k, device=device, exponential=True) | ||
| ndcg.update([y_pred, y_true]) | ||
| result_ignite = ndcg.compute() | ||
| result_sklearn = ndcg_score(2 ** y_true.numpy() - 1, y_pred.numpy(), k=k) | ||
| 
     | 
||
| np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-7) | ||
| 
     | 
||
| 
     | 
||
| @pytest.mark.parametrize( | ||
| "y_pred, y_true", | ||
| [ | ||
| (torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])), | ||
| (torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])), | ||
| ( | ||
| torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [3.7, 4.8, 3.9, 4.3, 4.9], [3.7, 4.8, 3.9, 4.3, 4.9]]), | ||
| torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.2, 4.5, 8.9, 5.6, 7.2], [2.9, 5.6, 3.8, 7.9, 6.2]]), | ||
| ), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("k", [None, 2, 3]) | ||
| def test_output_cpu_ignore_ties(y_pred, y_true, k): | ||
| 
     | 
||
| device = "cpu" | ||
| 
     | 
||
| ndcg = NDCG(k=k, device=device, ignore_ties=True) | ||
| ndcg.update([y_pred, y_true]) | ||
| result_ignite = ndcg.compute() | ||
| result_sklearn = ndcg_score(y_true.numpy(), y_pred.numpy(), k=k) | ||
| 
     | 
||
| np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-7) | ||
| 
     | 
||
| 
     | 
||
| @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") | ||
| @pytest.mark.parametrize( | ||
| "y_pred, y_true", | ||
| [ | ||
| (torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]), torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])), | ||
| (torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]), torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]])), | ||
| ( | ||
| torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5], [3.7, 4.8, 3.9, 4.3, 4.9], [3.7, 4.8, 3.9, 4.3, 4.9]]), | ||
| torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0], [1.2, 4.5, 8.9, 5.6, 7.2], [2.9, 5.6, 3.8, 7.9, 6.2]]), | ||
| ), | ||
| ], | ||
| ) | ||
| @pytest.mark.parametrize("k", [None, 2, 3]) | ||
| def test_output_gpu_ignore_ties(y_pred, y_true, k): | ||
                
      
                  vfdev-5 marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| 
     | 
||
| device = "cuda" | ||
| y_pred = y_pred.to(device) | ||
| y_true = y_true.to(device) | ||
| ndcg = NDCG(k=k, device=device, ignore_ties=True) | ||
| ndcg.update([y_pred, y_true]) | ||
| result_ignite = ndcg.compute() | ||
| result_sklearn = ndcg_score(y_true.cpu().numpy(), y_pred.cpu().numpy(), k=k) | ||
                
      
                  vfdev-5 marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| 
     | 
||
| np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-7) | ||
| 
     | 
||
| 
     | 
||
| @pytest.mark.parametrize("log_base", [2, 3, 10]) | ||
| def test_log_base(log_base): | ||
| def _ndcg_sample_scores(y_true, y_score, k=None, ignore_ties=False): | ||
| 
     | 
||
| gain = _dcg_sample_scores(y_true, y_score, k, ignore_ties=ignore_ties) | ||
| normalizing_gain = _dcg_sample_scores(y_true, y_true, k, ignore_ties=True) | ||
| all_irrelevant = normalizing_gain == 0 | ||
| gain[all_irrelevant] = 0 | ||
| gain[~all_irrelevant] /= normalizing_gain[~all_irrelevant] | ||
| return gain | ||
| 
     | 
||
| def ndcg_score_with_log_base(y_true, y_score, *, k=None, sample_weight=None, ignore_ties=False, log_base=2): | ||
| 
     | 
||
| gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties) | ||
| return np.average(gain, weights=sample_weight) | ||
                
      
                  vfdev-5 marked this conversation as resolved.
               
              
                Outdated
          
            Show resolved
            Hide resolved
         | 
||
| 
     | 
||
| y_true = torch.tensor([[3.7, 4.8, 3.9, 4.3, 4.9]]) | ||
| y_pred = torch.tensor([[2.9, 5.6, 3.8, 7.9, 6.2]]) | ||
| 
     | 
||
| ndcg = NDCG(log_base=log_base) | ||
| ndcg.update([y_pred, y_true]) | ||
| 
     | 
||
| result_ignite = ndcg.compute() | ||
| result_sklearn = ndcg_score_with_log_base(y_true.numpy(), y_pred.numpy(), log_base=log_base) | ||
| 
     | 
||
| np.testing.assert_allclose(np.array(result_ignite), result_sklearn, rtol=2e-7) | ||
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, there is no way to make it vectorized == without for-loop ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I havent checked yet. For now I have added this implementation. It's a TODO