@@ -728,3 +728,62 @@ def Summary(self, name):
728
728
def _CreateSummary (self , name ):
729
729
"""Returns a tf.Summary for this metric."""
730
730
raise NotImplementedError ()
731
+
732
+
733
+ class GroupPairAUCMetric (AUCMetric ):
734
+ """Compute the AUC score for all pairs extracted from each group of items.
735
+
736
+ For each group of items, the metric extracts all pairs with different
737
+ target values. For each pair (i, j), the metric computes the binary
738
+ classification AUC where the `label = 1 if target[i] > target[j] else 0` and
739
+ `prob = sigmoid(logits[i] - logits[j])`.
740
+
741
+ To prevent generating pairs across groups, an additional arg `group_ids` is
742
+ required, which is a list of ints that specifies the group_id of each item.
743
+
744
+ In addition, in order to achieve streaming computation, items from the same
745
+ group need to form continuous chunks,
746
+ e.g. group_ids = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2].
747
+
748
+ In the case of [0, 0, 1, 1, 1, 0, 0, 2, 2, 2, 2], the second chunk of 0s will
749
+ be treated as a separate 3rd group rather than part of the 1st group.
750
+ """
751
+
752
+ def UpdateRaw (self , group_ids , target , logits , weight = None ):
753
+ """Updates the metrics.
754
+
755
+ Args:
756
+ group_ids: An array to specify the group identity.
757
+ target: An array to specify the groundtruth float values.
758
+ logits: An array to specify the raw prediction logits.
759
+ weight: An array to specify the sample weight for the auc computation.
760
+ """
761
+
762
+ assert self ._samples <= 0
763
+
764
+ sigmoid = lambda x : 1.0 / (1.0 + np .exp (- x ))
765
+
766
+ def _ProcessChunk (s , e ):
767
+ for i in range (s , e ):
768
+ for j in range (i + 1 , e ):
769
+ if target [i ] != target [j ]:
770
+ pair_label = 1 if target [i ] > target [j ] else 0
771
+ pair_prob = sigmoid (logits [i ] - logits [j ])
772
+ self ._label .append (pair_label )
773
+ self ._prob .append (pair_prob )
774
+ if weight :
775
+ self ._weight .append (min (1.0 , weight [i ] + weight [j ]))
776
+ else :
777
+ self ._weight .append (1.0 )
778
+
779
+ s , e = 0 , 1
780
+ while e <= len (target ):
781
+ # Find the end of a chunk
782
+ if e == len (target ) or group_ids [e ] != group_ids [s ]:
783
+ # Process the current chunk [s:e]
784
+ _ProcessChunk (s , e )
785
+
786
+ # Start a new chunk by setting `s` to `e`
787
+ s = e
788
+ # Increment `e` by 1.
789
+ e += 1
0 commit comments