Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 38 additions & 43 deletions supervision/metrics/utils/object_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy.typing as npt

from supervision.config import ORIENTED_BOX_COORDINATES
from supervision.detection.core import Detections
from supervision.metrics.core import MetricTarget

if TYPE_CHECKING:
Expand Down Expand Up @@ -54,22 +55,15 @@ def get_bbox_size_category(xyxy: npt.NDArray[np.float32]) -> npt.NDArray[np.int_
xyxy (np.ndarray): The bounding boxes array shaped (N, 4).

Returns:
(np.ndarray) The size category of each bounding box, matching
the enum values of ObjectSizeCategory. Shaped (N,).
(np.ndarray) The size category of each bounding box.
"""
if len(xyxy.shape) != 2 or xyxy.shape[1] != 4:
if xyxy.ndim != 2 or xyxy.shape[1] != 4:
raise ValueError("Bounding boxes must be shaped (N, 4)")

width = xyxy[:, 2] - xyxy[:, 0]
height = xyxy[:, 3] - xyxy[:, 1]
areas = width * height

result = np.full(areas.shape, ObjectSizeCategory.ANY.value)
SM, LG = SIZE_THRESHOLDS
result[areas < SM] = ObjectSizeCategory.SMALL.value
result[(areas >= SM) & (areas < LG)] = ObjectSizeCategory.MEDIUM.value
result[areas >= LG] = ObjectSizeCategory.LARGE.value
return result
return _get_size_category_from_areas(areas)


def get_mask_size_category(mask: npt.NDArray[np.bool_]) -> npt.NDArray[np.int_]:
Expand All @@ -80,52 +74,35 @@ def get_mask_size_category(mask: npt.NDArray[np.bool_]) -> npt.NDArray[np.int_]:
mask (np.ndarray): The mask array shaped (N, H, W).

Returns:
(np.ndarray) The size category of each mask, matching
the enum values of ObjectSizeCategory. Shaped (N,).
(np.ndarray) The size category of each mask.
"""
if len(mask.shape) != 3:
if mask.ndim != 3:
raise ValueError("Masks must be shaped (N, H, W)")

areas = np.sum(mask, axis=(1, 2))

result = np.full(areas.shape, ObjectSizeCategory.ANY.value)
SM, LG = SIZE_THRESHOLDS
result[areas < SM] = ObjectSizeCategory.SMALL.value
result[(areas >= SM) & (areas < LG)] = ObjectSizeCategory.MEDIUM.value
result[areas >= LG] = ObjectSizeCategory.LARGE.value
return result
areas = np.sum(mask, axis=(1, 2)).astype(np.float32)
return _get_size_category_from_areas(areas)


def get_obb_size_category(xyxyxyxy: npt.NDArray[np.float32]) -> npt.NDArray[np.int_]:
"""
Get the size category of a oriented bounding boxes array.
Get the size category of an oriented bounding boxes array.

Args:
xyxyxyxy (np.ndarray): The bounding boxes array shaped (N, 4, 2).
xyxyxyxy (np.ndarray): The oriented bounding boxes array shaped (N, 4, 2).

Returns:
(np.ndarray) The size category of each bounding box, matching
the enum values of ObjectSizeCategory. Shaped (N,).
(np.ndarray) The size category of each oriented bounding box.
"""
if len(xyxyxyxy.shape) != 3 or xyxyxyxy.shape[1] != 4 or xyxyxyxy.shape[2] != 2:
if xyxyxyxy.ndim != 3 or xyxyxyxy.shape[1] != 4 or xyxyxyxy.shape[2] != 2:
raise ValueError("Oriented bounding boxes must be shaped (N, 4, 2)")

# Shoelace formula
x = xyxyxyxy[:, :, 0]
y = xyxyxyxy[:, :, 1]
x1, x2, x3, x4 = x.T
y1, y2, y3, y4 = y.T
# Shoelace formula using np.roll for a more concise vectorized computation.
areas = 0.5 * np.abs(
(x1 * y2 + x2 * y3 + x3 * y4 + x4 * y1)
- (x2 * y1 + x3 * y2 + x4 * y3 + x1 * y4)
np.sum(x * np.roll(y, -1, axis=1) - np.roll(x, -1, axis=1) * y, axis=1)
)

result = np.full(areas.shape, ObjectSizeCategory.ANY.value)
SM, LG = SIZE_THRESHOLDS
result[areas < SM] = ObjectSizeCategory.SMALL.value
result[(areas >= SM) & (areas < LG)] = ObjectSizeCategory.MEDIUM.value
result[areas >= LG] = ObjectSizeCategory.LARGE.value
return result
return _get_size_category_from_areas(areas)


def get_detection_size_category(
Expand All @@ -135,13 +112,11 @@ def get_detection_size_category(
Get the size category of a detections object.

Args:
xyxyxyxy (np.ndarray): The bounding boxes array shaped (N, 8).
metric_target (MetricTarget): Determines whether boxes, masks or
oriented bounding boxes are used.
detections (Detections): The detection object containing boxes, masks, etc.
metric_target (MetricTarget): Determines whether boxes, masks, or oriented boxes are used.

Returns:
(np.ndarray) The size category of each bounding box, matching
the enum values of ObjectSizeCategory. Shaped (N,).
(np.ndarray) The size category of each detection.
"""
if metric_target == MetricTarget.BOXES:
return get_bbox_size_category(detections.xyxy)
Expand All @@ -156,3 +131,23 @@ def get_detection_size_category(
np.array(detections.data[ORIENTED_BOX_COORDINATES])
)
raise ValueError("Invalid metric type")


# NOTE: ObjectSizeCategory and SIZE_THRESHOLDS are assumed to be defined in the
# internal codebase (for example, SIZE_THRESHOLDS under
# supervision.metrics.utils.object_size).
# ObjectSizeCategory is assumed to be an enum exposing SMALL, MEDIUM, LARGE and
# ANY members, each with a .value attribute.


def _get_size_category_from_areas(
    areas: npt.NDArray[np.float32],
) -> npt.NDArray[np.int_]:
    """
    Map an array of object areas to size-category values.

    Areas below the first entry of SIZE_THRESHOLDS are SMALL, areas at or
    above the second entry are LARGE, and areas in between are MEDIUM. Areas
    that satisfy none of these comparisons (NaN) are left as ANY.

    Args:
        areas (np.ndarray): The area of each object.

    Returns:
        (np.ndarray) The size category of each area, matching
            the enum values of ObjectSizeCategory. Same shape as `areas`.
    """
    small_limit, large_limit = SIZE_THRESHOLDS

    # Start from ANY rather than falling through to LARGE: NaN compares False
    # against every threshold, so a nested np.where would label it LARGE.
    categories = np.full(areas.shape, ObjectSizeCategory.ANY.value)

    is_small = areas < small_limit
    is_large = areas >= large_limit
    is_medium = (areas >= small_limit) & (areas < large_limit)

    categories[is_small] = ObjectSizeCategory.SMALL.value
    categories[is_medium] = ObjectSizeCategory.MEDIUM.value
    categories[is_large] = ObjectSizeCategory.LARGE.value
    return categories