From d50fcabf467ebbdfd4c37ec5b554c73d3967fb6c Mon Sep 17 00:00:00 2001
From: suniique
Date: Sun, 21 May 2023 17:03:03 +0200
Subject: [PATCH 1/2] add format_results

---
 examples/mmdet3d_dataset.py | 172 +++++++++++++++++++++++++++++++++++-
 1 file changed, 170 insertions(+), 2 deletions(-)

diff --git a/examples/mmdet3d_dataset.py b/examples/mmdet3d_dataset.py
index ef7e7bd..6a75752 100644
--- a/examples/mmdet3d_dataset.py
+++ b/examples/mmdet3d_dataset.py
@@ -36,14 +36,19 @@
 
 import os
 import sys
+import tempfile
 
 import mmcv
 import numpy as np
 from mmdet3d.core.bbox import CameraInstance3DBoxes
+from mmdet3d.core.bbox.structures import Box3DMode
 from mmdet3d.datasets.builder import DATASETS
 from mmdet3d.datasets.custom_3d import Custom3DDataset
 from mmdet3d.datasets.pipelines import LoadAnnotations3D
 
+from scalabel.label.typing import Dataset, Frame, Label, Box3D, RLE
+from scalabel.label.transforms import mask_to_rle
+
 # Add the root directory of the project to the path. Remove the following two lines
 # if you have installed shift_dev as a package.
 root_dir = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
@@ -69,6 +74,7 @@ def __init__(
         depth_prefix: str = "",
         backend_type: str = "file",
         img_to_float32: bool = False,
+        yaw_offset: float = np.pi / 2,
         **kwargs
     ):
         """Initialize the SHIFT dataset.
@@ -84,11 +90,14 @@ def __init__(
             backend_type (str, optional): The type of the backend. Must be one of ['file', 'zip', 'hdf5'].
                 Defaults to "file".
             img_to_float32 (bool, optional): Whether to convert the loaded image to float32. Defaults to False.
+            yaw_offset (float, optional): The yaw offset of the 3D bounding boxes. Defaults to np.pi / 2. This is
+                used to correct the heading of the 3D bounding boxes.
         """
         self.data_root = data_root
         self.ann_file = os.path.join(self.data_root, ann_file)
         self.img_prefix = os.path.join(self.data_root, img_prefix)
         self.img_to_float32 = img_to_float32
+        self.yaw_offset = yaw_offset
         self.insseg_ann_file = os.path.join(self.data_root, insseg_ann_file) if insseg_ann_file != "" else ""
         self.depth_prefix = os.path.join(self.data_root, depth_prefix) if depth_prefix != "" else ""
 
@@ -136,7 +145,7 @@ def load_annotations(self, ann_file):
                 box3d = label["box3d"]
                 boxes.append((box2d["x1"], box2d["y1"], box2d["x2"], box2d["y2"]))
                 boxes_3d.append(
-                    box3d["location"] + box3d["dimension"] + [box3d["orientation"][1] + np.pi / 2],  # yaw
+                    box3d["location"] + box3d["dimension"] + [box3d["orientation"][1] + self.yaw_offset],  # yaw
                 )
                 labels.append(self.CLASSES.index(label["category"]))
                 track_ids.append(label["id"])
@@ -198,7 +207,7 @@ def get_depth(self, idx):
         depth_img = self.read_image(depth_filename)
         depth_img = depth_img.astype(np.float32)
         depth = depth_img[:, :, 0] * 256 * 256 + depth_img[:, :, 1] * 256 + depth_img[:, :, 2]
-        depth = depth * DEPTH_FACTOR
+        depth /= self.DEPTH_FACTOR
         return depth
 
     def get_img_info(self, idx):
@@ -267,6 +276,165 @@ def prepare_test_data(self, idx):
         results["pad_shape"] = img_shape
         self.pre_pipeline(results)
         return self.pipeline(results)
+
+    @staticmethod
+    def _write_scalabel(frames, output_dir, filename) -> str:
+        """Write frames as a Scalabel dataset json file.
+
+        Args:
+            frames (list[Frame]): Frames to store in the dataset.
+            output_dir (str): Output directory. It is created if missing.
+            filename (str): Name of the json file inside output_dir.
+
+        Returns:
+            str: Path of the written json file.
+        """
+        os.makedirs(output_dir, exist_ok=True)
+        jsonfile = os.path.join(output_dir, filename)
+        ds = Dataset(frames=frames, groups=None, config=None)
+        with open(jsonfile, "w") as f:
+            f.write(ds.json(exclude_unset=True))
+        return jsonfile
+
+    def seg_results2scalabel(self, results, output_dir, seg_key="seg_results") -> str:
+        """Convert instance segmentation results to Scalabel format (det_insseg_2d.json).
+
+        Args:
+            results (list[dict]): Testing results of the dataset.
+            output_dir (str): Output directory of the results in Scalabel
+                format. It is created if it does not exist.
+            seg_key (str): Key of instance segmentation results in results.
+                Default: "seg_results".
+
+        Returns:
+            str: Path of the converted results in Scalabel format.
+
+        Raises:
+            ValueError: If a segmentation item has neither "rle" nor "mask".
+        """
+        frames = []
+        for result in results:
+            labels = []
+            for seg_label in result[seg_key]:
+                if "rle" in seg_label:
+                    counts = seg_label["rle"]["counts"]
+                    # Only bytes need decoding; counts that are already str are used as is.
+                    if isinstance(counts, bytes):
+                        counts = counts.decode("utf-8")
+                    size = seg_label["rle"]["size"]
+                    rle = RLE(counts=counts, size=(size[0], size[1]))
+                elif "mask" in seg_label:
+                    rle = mask_to_rle(seg_label["mask"])
+                else:
+                    # Fail loudly instead of re-appending the previous label (or hitting a NameError).
+                    raise ValueError('segmentation result must contain "rle" or "mask"')
+                labels.append(
+                    Label(id=seg_label["id"], category=self.CLASSES[seg_label["category_id"]], rle=rle)
+                )
+            frames.append(
+                Frame(
+                    name=result["image_name"],
+                    videoName=result["video_name"],
+                    frameIndex=result["image_name"].split("_")[0],
+                    labels=labels,
+                )
+            )
+        return self._write_scalabel(frames, output_dir, "det_insseg_2d.json")
+
+    def box_results2scalabel(self, results, output_dir, box_key="pts_bbox") -> str:
+        """Convert 3D object detection results to Scalabel format (det_3d.json).
+
+        Args:
+            results (list[dict]): Testing results of the dataset.
+            output_dir (str): Output directory of the results in Scalabel
+                format. It is created if it does not exist.
+            box_key (str): Key of 3D detection results in results.
+                Default: "pts_bbox".
+
+        Returns:
+            str: Path of the output json file.
+        """
+        frames = []
+        for result in results:
+            pred = result[box_key]
+            # Convert the whole box container at once: iterating it yields plain tensor rows,
+            # which have no convert_to() method.
+            boxes_cam = pred["boxes_3d"].convert_to(Box3DMode.CAM).tensor.tolist()
+            class_ids = pred["labels_3d"].tolist()
+            scores = pred["scores_3d"].tolist()
+            labels = []
+            for label_id, (box, class_id, score) in enumerate(zip(boxes_cam, class_ids, scores)):
+                labels.append(
+                    Label(
+                        id=str(label_id),
+                        category=self.CLASSES[class_id],
+                        box3d=Box3D(
+                            alpha=0,
+                            location=tuple(box[0:3]),
+                            dimension=tuple(box[3:6]),
+                            # Undo the yaw offset that load_annotations adds to the heading.
+                            orientation=(0, box[6] - self.yaw_offset, 0),
+                        ),
+                        score=score,
+                    )
+                )
+            frames.append(
+                Frame(
+                    name=result["image_name"],
+                    videoName=result["video_name"],
+                    frameIndex=result["image_name"].split("_")[0],
+                    labels=labels,
+                )
+            )
+        return self._write_scalabel(frames, output_dir, "det_3d.json")
+
+    def format_results(self, results, jsonfile_prefix=None, **kwargs):
+        """Format the results to json (standard format for Scalabel evaluation).
+
+        Args:
+            results (list[dict]): Testing results of the dataset.
+            jsonfile_prefix (str | None): Directory in which the json files
+                are saved, e.g., "a/b/prefix". It is created if it does not exist.
+                If not specified, a temp directory will be created. Default: None.
+            **kwargs: "seg_key" and "box_key" are forwarded to the matching
+                converter. Other arguments are ignored.
+
+        Note:
+            The results are in the format of list[dict], where each dict contains:
+                - "image_name" (str): Filename of the image.
+                - "video_name" (str): Filename of the video.
+                - "pts_bbox" (dict): 3D detection results, which contains:
+                    - "boxes_3d" (BaseInstance3DBoxes): Predicted 3D boxes of shape (N, dim).
+                    - "labels_3d" (torch.Tensor): Predicted labels of shape (N,).
+                    - "scores_3d" (torch.Tensor): Scores of predicted boxes of shape (N,).
+                - "seg_results" (list): Instance segmentation results, whose item contains either:
+                    - "rle" (dict): RLE encoded segmentation mask, or
+                    - "mask" (np.ndarray): Segmentation mask of shape (H, W).
+
+        Returns:
+            tuple: (result_files, tmp_dir), result_files is a dict containing
+                the json filepaths, tmp_dir is the temporary directory created
+                for saving json files when jsonfile_prefix is not specified.
+        """
+        assert isinstance(results, list), "results must be a list"
+
+        if jsonfile_prefix is None:
+            tmp_dir = tempfile.TemporaryDirectory()
+            jsonfile_prefix = os.path.join(tmp_dir.name, "results")
+        else:
+            tmp_dir = None
+
+        # Each converter only accepts its own key argument, so pick them out of kwargs
+        # instead of forwarding everything to both.
+        result_files = {
+            "det_insseg_2d": self.seg_results2scalabel(
+                results, jsonfile_prefix, seg_key=kwargs.get("seg_key", "seg_results")
+            ),
+            "det_3d": self.box_results2scalabel(
+                results, jsonfile_prefix, box_key=kwargs.get("box_key", "pts_bbox")
+            ),
+        }
+        return result_files, tmp_dir
 
 
 if __name__ == "__main__":

From 1daec4f7d57a9fa875f7d4c71a319a3a5893aac4 Mon Sep 17 00:00:00 2001
From: suniique
Date: Sun, 21 May 2023 17:09:28 +0200
Subject: [PATCH 2/2] update box convert

---
 examples/mmdet3d_dataset.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/mmdet3d_dataset.py b/examples/mmdet3d_dataset.py
index 6a75752..10f0bfc 100644
--- a/examples/mmdet3d_dataset.py
+++ b/examples/mmdet3d_dataset.py
@@ -255,8 +255,13 @@ def prepare_train_data(self, idx):
         results = dict(img=img, img_info=img_info, cam2img=self.cam_intrinsic, ann_info=ann_info)
         if self.depth_prefix != "":
             results["gt_depth"] = self.get_depth(idx)
+
         # Add lidar2cam matrix for compatibility (e.g., PETR)
-        results["lidar2cam"] = np.eye(4)
+        if self.box_mode_3d == Box3DMode.LIDAR:
+            results["lidar2cam"] = np.array([[0, -1, 0], [0, 0, -1], [1, 0, 0]])
+        elif self.box_mode_3d == Box3DMode.DEPTH:
+            results["depth2cam"] = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
+
         # Set initial shape for mmdet3d pipeline compatibility
         img_shape = img[0][..., np.newaxis].shape
         results["img_shape"] = img_shape