Skip to content

Commit f4192cb

Browse files
lingvo-botcopybara-github
authored andcommitted
Add group pairwise AUC metric for a list of predictions.
PiperOrigin-RevId: 480180039
1 parent f5bbf09 commit f4192cb

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

lingvo/core/metrics.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,3 +728,62 @@ def Summary(self, name):
728728
def _CreateSummary(self, name):
729729
"""Returns a tf.Summary for this metric."""
730730
raise NotImplementedError()
731+
732+
733+
class GroupPairAUCMetric(AUCMetric):
734+
"""Compute the AUC score for all pairs extracted from each group of items.
735+
736+
For each group of items, the metric extracts all pairs with different
737+
target values. For each pair (i, j), the metric computes the binary
738+
classification AUC where the `label = 1 if target[i] > target[j] else 0` and
739+
`prob = sigmoid(logits[i] - logits[j])`.
740+
741+
To prevent generating pairs across groups, an additional arg `group_ids` is
742+
required, which is a list of ints that specifies the group_id of each item.
743+
744+
In addition, in order to achieve streaming computation, items from the same
745+
group need to form continuous chunks,
746+
e.g. group_ids = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2].
747+
748+
In the case of [0, 0, 1, 1, 1, 0, 0, 2, 2, 2, 2], the second chunk of 0s will
749+
be treated as a separate 3rd group rather than part of the 1st group.
750+
"""
751+
752+
def UpdateRaw(self, group_ids, target, logits, weight=None):
753+
"""Updates the metrics.
754+
755+
Args:
756+
group_ids: An array to specify the group identity.
757+
target: An array to specify the groundtruth float values.
758+
logits: An array to specify the raw prediction logits.
759+
weight: An array to specify the sample weight for the auc computation.
760+
"""
761+
762+
assert self._samples <= 0
763+
764+
sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
765+
766+
def _ProcessChunk(s, e):
767+
for i in range(s, e):
768+
for j in range(i + 1, e):
769+
if target[i] != target[j]:
770+
pair_label = 1 if target[i] > target[j] else 0
771+
pair_prob = sigmoid(logits[i] - logits[j])
772+
self._label.append(pair_label)
773+
self._prob.append(pair_prob)
774+
if weight:
775+
self._weight.append(min(1.0, weight[i] + weight[j]))
776+
else:
777+
self._weight.append(1.0)
778+
779+
s, e = 0, 1
780+
while e <= len(target):
781+
# Find the end of a chunk
782+
if e == len(target) or group_ids[e] != group_ids[s]:
783+
# Process the current chunk [s:e]
784+
_ProcessChunk(s, e)
785+
786+
# Start a new chunk by setting `s` to `e`
787+
s = e
788+
# Increment `e` by 1.
789+
e += 1

lingvo/core/metrics_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,32 @@ def testMultiClassAUCMetric(self):
246246
# Verify average AUC value.
247247
self.assertAllClose(0.722222222, m.value)
248248

249+
def testGroupPairAUCMetric(self):
250+
if not metrics.HAS_SKLEARN:
251+
self.skipTest('sklearn is not installed.')
252+
pair_m = metrics.AUCMetric()
253+
group_m = metrics.GroupPairAUCMetric()
254+
group_ids = [0, 0, 0, 1, 1, 1, 2, 2]
255+
target = np.random.rand(8).tolist()
256+
logits = np.random.rand(8).tolist()
257+
weight = [1.0] * 8
258+
group_m.UpdateRaw(
259+
group_ids=group_ids, target=target, logits=logits, weight=weight)
260+
261+
sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
262+
left, right = 0, 0
263+
while right < len(group_ids):
264+
while right < len(group_ids) and group_ids[right] == group_ids[left]:
265+
right += 1
266+
for i in range(left, right):
267+
for j in range(i + 1, right):
268+
if group_ids[i] == group_ids[j] and target[i] != target[j]:
269+
pair_m.Update(
270+
label=[1 if target[i] > target[j] else 0],
271+
prob=[sigmoid(logits[i] - logits[j])],
272+
weight=[min(1.0, weight[i] + weight[j])])
273+
left = right
274+
self.assertEqual(pair_m.value, group_m.value)
249275

250276
if __name__ == '__main__':
251277
test_utils.main()

0 commit comments

Comments
 (0)