From 2834862824625ba3360b200ff611d71783b1e397 Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 07:04:19 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20method=20`Rec?= =?UTF-8?q?all.=5Fcompute=5Fconfusion=5Fmatrix`=20by=2051%=20Here=20is=20t?= =?UTF-8?q?he=20optimized=20version.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/metrics/recall.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/supervision/metrics/recall.py b/supervision/metrics/recall.py index b3586ff7d..0de53a183 100644 --- a/supervision/metrics/recall.py +++ b/supervision/metrics/recall.py @@ -334,27 +334,31 @@ class ids. num_thresholds = sorted_matches.shape[1] num_classes = unique_classes.shape[0] + # Initialize confusion matrix confusion_matrix = np.zeros((num_classes, num_thresholds, 3)) + + # Vectorized approach to handle the class-based confusion matrix updates for class_idx, class_id in enumerate(unique_classes): - is_class = sorted_prediction_class_ids == class_id + class_mask = sorted_prediction_class_ids == class_id num_true = class_counts[class_idx] - num_predictions = is_class.sum() + num_predictions = class_mask.sum() if num_predictions == 0: - true_positives = np.zeros(num_thresholds) - false_positives = np.zeros(num_thresholds) - false_negatives = np.full(num_thresholds, num_true) + confusion_matrix[class_idx, :, 0] = 0 # true_positives + confusion_matrix[class_idx, :, 1] = 0 # false_positives + confusion_matrix[class_idx, :, 2] = num_true # false_negatives elif num_true == 0: - true_positives = np.zeros(num_thresholds) - false_positives = np.full(num_thresholds, num_predictions) - false_negatives = np.zeros(num_thresholds) + confusion_matrix[class_idx, :, 0] = 0 # true_positives + confusion_matrix[class_idx, :, 1] = num_predictions # false_positives + confusion_matrix[class_idx, :, 2] = 0 # false_negatives else: - true_positives = sorted_matches[is_class].sum(0) - false_positives = (1 - sorted_matches[is_class]).sum(0) + true_positives = sorted_matches[class_mask].sum(axis=0) + false_positives = class_mask.sum() - true_positives false_negatives = num_true - true_positives - confusion_matrix[class_idx] = np.stack( - [true_positives, false_positives, false_negatives], axis=1 - ) + + confusion_matrix[class_idx, :, 0] = true_positives + confusion_matrix[class_idx, :, 1] = false_positives + confusion_matrix[class_idx, :, 2] = false_negatives return confusion_matrix