@@ -61,8 +61,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
         num_classes = 2 if self._type == "binary" else y_pred.size(1)
         if self._type == "multiclass" and y.max() + 1 > num_classes:
             raise ValueError(
-                f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
-                f" and element in y has invalid class = {y.max().item() + 1}."
+                f"y_pred contains fewer classes than y. Number of classes in the prediction is {num_classes}"
+                f" and an element in y has invalid class = {y.max().item() + 1}."
             )
         y = y.view(-1)
         if self._type == "binary" and self._average is False:
@@ -86,30 +86,32 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
 
     @reinit__is_reduced
     def reset(self) -> None:
-        # `numerator`, `denominator` and `weight` are three variables chosen to be abstract
-        # representatives of the ones that are measured for cases with different `average` parameters.
-        # `weight` is only used when `average='weighted'`. Actual value of these three variables is
-        # as follows.
-        #
-        # average='samples':
-        #     numerator (torch.Tensor): sum of metric value for samples
-        #     denominator (int): number of samples
-        #
-        # average='weighted':
-        #     numerator (torch.Tensor): number of true positives per class/label
-        #     denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
-        #         positives per class/label
-        #     weight (torch.Tensor): number of actual positives per class
-        #
-        # average='micro':
-        #     numerator (torch.Tensor): sum of number of true positives for classes/labels
-        #     denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives
-        #         for classes/labels
-        #
-        # average='macro' or boolean or None:
-        #     numerator (torch.Tensor): number of true positives per class/label
-        #     denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
-        #         positives per class/label
+        """
+        `numerator`, `denominator` and `weight` are three variables chosen to be abstract
+        representatives of the ones that are measured for cases with different `average` parameters.
+        `weight` is only used when `average='weighted'`. The actual values of these three variables are
+        as follows.
+
+        average='samples':
+            numerator (torch.Tensor): sum of metric values over samples
+            denominator (int): number of samples
+
+        average='weighted':
+            numerator (torch.Tensor): number of true positives per class/label
+            denominator (torch.Tensor): number of predicted (for precision) or actual (for recall) positives per
+                class/label.
+            weight (torch.Tensor): number of actual positives per class
+
+        average='micro':
+            numerator (torch.Tensor): sum of the number of true positives over classes/labels
+            denominator (torch.Tensor): sum of the number of predicted (for precision) or actual (for recall) positives
+                over classes/labels.
+
+        average='macro' or boolean or None:
+            numerator (torch.Tensor): number of true positives per class/label
+            denominator (torch.Tensor): number of predicted (for precision) or actual (for recall) positives per
+                class/label.
+        """
 
         self._numerator: Union[int, torch.Tensor] = 0
         self._denominator: Union[int, torch.Tensor] = 0
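
The accumulators described in the new `reset` docstring can be pictured with a small standalone sketch. This is not ignite's internal code; the tensors and the one-hot step are assumptions used only to illustrate what "true positives", "predicted positives" and "actual positives" per class look like for precision.

```python
# Illustrative sketch only -- not ignite's implementation. It mimics what the
# abstract `numerator`, `denominator` and `weight` accumulators hold for
# precision on a tiny multiclass batch.
import torch
import torch.nn.functional as F

num_classes = 3
y_pred = torch.tensor([0, 2, 1, 2, 0])  # predicted class indices, shape (N,)
y = torch.tensor([0, 1, 1, 2, 2])       # ground-truth class indices, shape (N,)

y_pred_oh = F.one_hot(y_pred, num_classes)  # (N, C)
y_oh = F.one_hot(y, num_classes)            # (N, C)

numerator = (y_pred_oh * y_oh).sum(dim=0)   # true positives per class
denominator = y_pred_oh.sum(dim=0)          # predicted positives per class (precision)
weight = y_oh.sum(dim=0)                    # actual positives per class ('weighted' only)

print(numerator)    # tensor([1, 1, 1])
print(denominator)  # tensor([2, 1, 2])
print(weight)       # tensor([1, 2, 2])
```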
@@ -120,16 +122,20 @@ def reset(self) -> None:
 
     @sync_all_reduce("_numerator", "_denominator")
     def compute(self) -> Union[torch.Tensor, float]:
-        # Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
-        #
-        # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
-        #
-        # wherein `weight` is the internal variable `weight` for `'weighted'` option and :math:`1/C`
-        # for the `macro` one. :math:`C` is the number of classes/labels.
-        #
-        # Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
-        #
-        # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator }
+        r"""
+        The return value of the metric for the `average` options `'weighted'` and `'macro'` is computed as follows.
+
+        .. math::
+            \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
+
+        wherein `weight` is the internal variable `_weight` for the `'weighted'` option and :math:`1/C`
+        for the `'macro'` one. :math:`C` is the number of classes/labels.
+
+        The return value of the metric for the `average` options `'micro'`, `'samples'`, `False` and `None` is as follows.
+
+        .. math::
+            \text{Precision/Recall} = \frac{ numerator }{ denominator }
+        """
 
         if not self._updated:
             raise NotComputableError(
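
Continuing the earlier sketch, the `compute` formulas above reduce to a few lines. The normalization of `weight` by the total support for the `'weighted'` option is an assumption (matching scikit-learn's convention), not something the docstring states explicitly.

```python
# Sketch of the `compute` formulas above, reusing the per-class accumulators
# from the previous example. The weight normalization for 'weighted' is an
# assumption (support-weighted mean, as in scikit-learn).
import torch

numerator = torch.tensor([1.0, 1.0, 1.0])    # true positives per class
denominator = torch.tensor([2.0, 1.0, 2.0])  # predicted positives per class
weight = torch.tensor([1.0, 2.0, 2.0])       # actual positives per class

fraction = numerator / denominator           # per-class precision

macro = fraction.mean()                                # weight = 1/C
weighted = (fraction * weight / weight.sum()).sum()    # weight = normalized support
micro = numerator.sum() / denominator.sum()            # single global ratio

print(macro.item())     # 0.666...
print(weighted.item())  # 0.7
print(micro.item())     # 0.6
```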
@@ -367,6 +373,33 @@ def thresholded_output_transform(output):
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
+        r"""
+        Update the metric state using the prediction and target.
+
+        Args:
+            output: a binary tuple of tensors (y_pred, y) whose shapes follow the table below. N stands for the batch
+                dimension, `...` for possible additional dimensions and C for the class dimension.
+
+                .. list-table::
+                    :widths: 20 10 10 10
+                    :header-rows: 1
+
+                    * - Output member\\Data type
+                      - Binary
+                      - Multiclass
+                      - Multilabel
+                    * - y_pred
+                      - (N, ...)
+                      - (N, C, ...)
+                      - (N, C, ...)
+                    * - y
+                      - (N, ...)
+                      - (N, ...)
+                      - (N, C, ...)
+
+                For binary and multilabel data, both y and y_pred should consist of 0's and 1's, but for multiclass
+                data, y_pred and y should consist of probabilities and integers, respectively.
+        """
         self._check_shape(output)
         self._check_type(output)
         y_pred, y, correct = self._prepare_output(output)
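
Finally, a hedged usage example of `update` following the shape table above. The `Precision` import path and the string values for `average` are taken from the docstrings in this diff; exact argument support depends on the ignite version, so treat the details as assumptions.

```python
# Hedged usage sketch for `update`, following the shape table above. The
# `Precision` class and its `average` values come from the docstrings in this
# diff; treat version-specific details as assumptions.
import torch
from ignite.metrics import Precision

# Multiclass: y_pred is (N, C) probabilities, y is (N,) integer class indices.
precision = Precision(average="macro")
y_pred = torch.tensor([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1],
                       [0.2, 0.2, 0.6],
                       [0.6, 0.3, 0.1]])
y = torch.tensor([0, 1, 2, 0])
precision.update((y_pred, y))
print(precision.compute())

# Binary: both tensors are (N,) and consist of 0's and 1's.
precision_bin = Precision(average=False)
precision_bin.update((torch.tensor([1, 0, 1, 1]), torch.tensor([1, 0, 0, 1])))
print(precision_bin.compute())
```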