From c22456d4665995a00ee24e34f43993e7db99074a Mon Sep 17 00:00:00 2001 From: Deepti Hegde Date: Thu, 8 Apr 2021 11:23:42 -0400 Subject: [PATCH 1/3] Update train_rcnn.py --- pointrcnn/tools/train_rcnn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pointrcnn/tools/train_rcnn.py b/pointrcnn/tools/train_rcnn.py index 0dafbdc..556fadc 100644 --- a/pointrcnn/tools/train_rcnn.py +++ b/pointrcnn/tools/train_rcnn.py @@ -69,12 +69,13 @@ def create_dataloader(logger): DATA_PATH = args.root # create dataloader - train_set = KittiRCNNDataset(root_dir=DATA_PATH, npoints=cfg.RPN.NUM_POINTS, split=cfg.TRAIN.SPLIT, mode='TRAIN', + train_set = KittiRCNNDataset(root_dir=DATA_PATH, npoints=cfg.RPN.NUM_POINTS, + split=cfg.TRAIN.SPLIT, classes=cfg.CLASSES,mode='TRAIN', logger=logger, - classes=cfg.CLASSES, npoints_faraway=args.npoints_faraway, + far_points=args.npoints_faraway, rcnn_training_roi_dir=args.rcnn_training_roi_dir, rcnn_training_feature_dir=args.rcnn_training_feature_dir, - gt_database_dir=args.gt_database, with_replace=args.with_replace, + gt_database_dir=None, with_replace=args.with_replace, subsample=args.subsample, shuffle_subsample=args.shuffle_subsample) train_loader = DataLoader(train_set, batch_size=args.batch_size, pin_memory=True, num_workers=args.workers, shuffle=True, collate_fn=train_set.collate_batch, From 3d6161847a4321476146b6d2e1f7320b9a75a5ec Mon Sep 17 00:00:00 2001 From: Deepti Hegde Date: Thu, 8 Apr 2021 11:25:09 -0400 Subject: [PATCH 2/3] Update kitti_rcnn_dataset.py --- pointrcnn/lib/datasets/kitti_rcnn_dataset.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pointrcnn/lib/datasets/kitti_rcnn_dataset.py b/pointrcnn/lib/datasets/kitti_rcnn_dataset.py index 81c26f4..69ddfdd 100644 --- a/pointrcnn/lib/datasets/kitti_rcnn_dataset.py +++ b/pointrcnn/lib/datasets/kitti_rcnn_dataset.py @@ -11,8 +11,9 @@ class KittiRCNNDataset(KittiDataset): def __init__(self, root_dir, npoints=16384, split='train', classes='Car', mode='TRAIN', random_select=True, - logger=None, rcnn_training_roi_dir=None, rcnn_training_feature_dir=None, rcnn_eval_roi_dir=None, - rcnn_eval_feature_dir=None, gt_database_dir=None, far_points=200): + logger=None,far_points=200, rcnn_training_roi_dir=None, rcnn_training_feature_dir=None, rcnn_eval_roi_dir=None, + rcnn_eval_feature_dir=None, gt_database_dir=None, with_replace=False, + subsample=1, shuffle_subsample=None): super().__init__(root_dir=root_dir, split=split) if classes == 'Car': self.classes = ('Background', 'Car') From 9508edfddeb81cbe7c87ffc0d6cdc8a5a4804c1b Mon Sep 17 00:00:00 2001 From: Deepti Hegde Date: Thu, 8 Apr 2021 11:30:29 -0400 Subject: [PATCH 3/3] Update train_rcnn.py --- pointrcnn/tools/train_rcnn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pointrcnn/tools/train_rcnn.py b/pointrcnn/tools/train_rcnn.py index 556fadc..9ccc0a4 100644 --- a/pointrcnn/tools/train_rcnn.py +++ b/pointrcnn/tools/train_rcnn.py @@ -75,7 +75,7 @@ def create_dataloader(logger): far_points=args.npoints_faraway, rcnn_training_roi_dir=args.rcnn_training_roi_dir, rcnn_training_feature_dir=args.rcnn_training_feature_dir, - gt_database_dir=None, with_replace=args.with_replace, + gt_database_dir=args.gt_database, with_replace=args.with_replace, subsample=args.subsample, shuffle_subsample=args.shuffle_subsample) train_loader = DataLoader(train_set, batch_size=args.batch_size, pin_memory=True, num_workers=args.workers, shuffle=True, collate_fn=train_set.collate_batch, @@ -84,8 +84,9 @@ def create_dataloader(logger): if args.train_with_eval: test_set = KittiRCNNDataset(root_dir=DATA_PATH, npoints=cfg.RPN.NUM_POINTS, split=cfg.TRAIN.VAL_SPLIT, mode='EVAL', logger=logger, + far_points=args.npoints_faraway, classes=cfg.CLASSES, - rcnn_eval_roi_dir=args.rcnn_eval_roi_dir, npoints_faraway=args.npoints_faraway, + rcnn_eval_roi_dir=args.rcnn_eval_roi_dir, rcnn_eval_feature_dir=args.rcnn_eval_feature_dir, with_replace=args.with_replace) test_loader = DataLoader(test_set, batch_size=1, shuffle=True, pin_memory=True, num_workers=args.workers, collate_fn=test_set.collate_batch)