diff --git a/pointrcnn/lib/datasets/kitti_rcnn_dataset.py b/pointrcnn/lib/datasets/kitti_rcnn_dataset.py index 81c26f4..69ddfdd 100644 --- a/pointrcnn/lib/datasets/kitti_rcnn_dataset.py +++ b/pointrcnn/lib/datasets/kitti_rcnn_dataset.py @@ -11,8 +11,9 @@ class KittiRCNNDataset(KittiDataset): def __init__(self, root_dir, npoints=16384, split='train', classes='Car', mode='TRAIN', random_select=True, - logger=None, rcnn_training_roi_dir=None, rcnn_training_feature_dir=None, rcnn_eval_roi_dir=None, - rcnn_eval_feature_dir=None, gt_database_dir=None, far_points=200): + logger=None,far_points=200, rcnn_training_roi_dir=None, rcnn_training_feature_dir=None, rcnn_eval_roi_dir=None, + rcnn_eval_feature_dir=None, gt_database_dir=None, with_replace=False, + subsample=1, shuffle_subsample=None): super().__init__(root_dir=root_dir, split=split) if classes == 'Car': self.classes = ('Background', 'Car') diff --git a/pointrcnn/tools/train_rcnn.py b/pointrcnn/tools/train_rcnn.py index 0dafbdc..9ccc0a4 100644 --- a/pointrcnn/tools/train_rcnn.py +++ b/pointrcnn/tools/train_rcnn.py @@ -69,9 +69,10 @@ def create_dataloader(logger): DATA_PATH = args.root # create dataloader - train_set = KittiRCNNDataset(root_dir=DATA_PATH, npoints=cfg.RPN.NUM_POINTS, split=cfg.TRAIN.SPLIT, mode='TRAIN', + train_set = KittiRCNNDataset(root_dir=DATA_PATH, npoints=cfg.RPN.NUM_POINTS, + split=cfg.TRAIN.SPLIT, classes=cfg.CLASSES,mode='TRAIN', logger=logger, - classes=cfg.CLASSES, npoints_faraway=args.npoints_faraway, + far_points=args.npoints_faraway, rcnn_training_roi_dir=args.rcnn_training_roi_dir, rcnn_training_feature_dir=args.rcnn_training_feature_dir, gt_database_dir=args.gt_database, with_replace=args.with_replace, @@ -83,8 +84,9 @@ def create_dataloader(logger): if args.train_with_eval: test_set = KittiRCNNDataset(root_dir=DATA_PATH, npoints=cfg.RPN.NUM_POINTS, split=cfg.TRAIN.VAL_SPLIT, mode='EVAL', logger=logger, + far_points=args.npoints_faraway, classes=cfg.CLASSES, - rcnn_eval_roi_dir=args.rcnn_eval_roi_dir, npoints_faraway=args.npoints_faraway, + rcnn_eval_roi_dir=args.rcnn_eval_roi_dir, rcnn_eval_feature_dir=args.rcnn_eval_feature_dir, with_replace=args.with_replace) test_loader = DataLoader(test_set, batch_size=1, shuffle=True, pin_memory=True, num_workers=args.workers, collate_fn=test_set.collate_batch)