diff --git a/.gitignore b/.gitignore index 948a72d..b9a950b 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,6 @@ dataset/scannetv2/test dataset/scannetv2/val_gt +scans/* +test.ipynb +model/pointgroup/pointgroup.pth diff --git a/README.md b/README.md index 4b95358..3326369 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,29 @@ # PointGroup + ## PointGroup: Dual-Set Point Grouping for 3D Instance Segmentation (CVPR2020) + ![overview](https://github.com/llijiang/PointGroup/blob/master/doc/overview.png) Code for the paper **PointGroup:Dual-Set Point Grouping for 3D Instance Segmentation**, CVPR 2020 (Oral). -**Authors**: Li Jiang, Hengshuang Zhao, Shaoshuai Shi, Shu Liu, Chi-Wing Fu, Jiaya Jia +**Authors**: Li Jiang, Hengshuang Zhao, Shaoshuai Shi, Shu Liu, Chi-Wing Fu, Jiaya Jia [[arxiv]](https://arxiv.org/abs/2004.01658) [[video]](https://youtu.be/HMetye3gmAs) ## Introduction + Instance segmentation is an important task for scene understanding. Compared to the fully-developed 2D, 3D instance segmentation for point clouds have much room to improve. In this paper, we present PointGroup, a new end-to-end bottom-up architecture, specifically focused on better grouping the points by exploring the void space between objects. We design a two-branch network to extract point features and predict semantic labels and offsets, for shifting each point towards its respective instance centroid. A clustering component is followed to utilize both the original and offset-shifted point coordinate sets, taking advantage of their complementary strength. Further, we formulate the ScoreNet to evaluate the candidate instances, followed by the Non-Maximum Suppression (NMS) to remove duplicates. 
## Installation ### Requirements -* Python 3.7.0 -* Pytorch 1.1.0 -* CUDA 9.0 + +- Python 3.7.0 +- Pytorch 1.1.0 +- CUDA 9.0 ### Virtual Environment + ``` conda create -n pointgroup python==3.7 source activate pointgroup @@ -27,66 +32,90 @@ source activate pointgroup ### Install `PointGroup` (1) Clone the PointGroup repository. + ``` -git clone https://github.com/llijiang/PointGroup.git --recursive +git clone https://github.com/llijiang/PointGroup.git --recursive cd PointGroup ``` (2) Install the dependent libraries. + ``` pip install -r requirements.txt -conda install -c bioconda google-sparsehash +conda install -c bioconda google-sparsehash ``` -(3) For the SparseConv, we apply the implementation of [spconv](https://github.com/traveller59/spconv). The repository is recursively downloaded at step (1). We use the version 1.0 of spconv. +(3) For the SparseConv, we apply the implementation of [spconv](https://github.com/traveller59/spconv). The repository is recursively downloaded at step (1). We use the version 1.0 of spconv. **Note:** We further modify `spconv\spconv\functional.py` to make `grad_output` contiguous. Make sure you use our modified `spconv`. -* To compile `spconv`, firstly install the dependent libraries. -``` +- To compile `spconv`, firstly install the dependent libraries. + +```bash conda install libboost conda install -c daleydeng gcc-5 # need gcc-5.4 for sparseconv +sudo apt-get update && sudo apt-get install -y libboost-all-dev libsparsehash-dev ``` + Add the `$INCLUDE_PATH$` that contains `boost` in `lib/spconv/CMakeLists.txt`. (Not necessary if it could be found.) + ``` include_directories($INCLUDE_PATH$) ``` -* Compile the `spconv` library. -``` +- Compile the `spconv` library. + +1. go to `lib/spconv/src/spconv/all.cc`, run +2. On line number 20 replace `torch::jit::RegisterOperators` with `torch::RegisterOperators` +3. Or install + + 1. 
```bash + pip install spconv-1.0-cp39-cp39-linux_x86_64.whl + ``` + +```bash cd lib/spconv python setup.py bdist_wheel ``` -* Run `cd dist` and use pip to install the generated `.whl` file. - - +- Run `cd dist` and use pip to install the generated `.whl` file. (4) Compile the `pointgroup_ops` library. + ``` cd lib/pointgroup_ops python setup.py develop ``` -If any header files could not be found, run the following commands. + +If any header files could not be found, run the following commands. + ``` python setup.py build_ext --include-dirs=$INCLUDE_PATH$ python setup.py develop ``` -`$INCLUDE_PATH$` is the path to the folder containing the header files that could not be found. +`$INCLUDE_PATH$` is the path to the folder containing the header files that could not be found. ## Data Preparation -(1) Download the [ScanNet](http://www.scan-net.org/) v2 dataset. +(1) Run script to download the ScanNet v2 dataset. + +```python +python download_scannetv2.py -o . +``` + +(2) Download the [ScanNet](http://www.scan-net.org/) v2 dataset. -(2) Put the data in the corresponding folders. -* Copy the files `[scene_id]_vh_clean_2.ply`, `[scene_id]_vh_clean_2.labels.ply`, `[scene_id]_vh_clean_2.0.010000.segs.json` and `[scene_id].aggregation.json` into the `dataset/scannetv2/train` and `dataset/scannetv2/val` folders according to the ScanNet v2 train/val [split](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark). +(4) Put the data in the corresponding folders. -* Copy the files `[scene_id]_vh_clean_2.ply` into the `dataset/scannetv2/test` folder according to the ScanNet v2 test [split](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark). 
+- Copy the files `[scene_id]_vh_clean_2.ply`, `[scene_id]_vh_clean_2.labels.ply`, `[scene_id]_vh_clean_2.0.010000.segs.json` and `[scene_id].aggregation.json` into the `dataset/scannetv2/train` and `dataset/scannetv2/val` folders according to the ScanNet v2 train/val [split](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark). -* Put the file `scannetv2-labels.combined.tsv` in the `dataset/scannetv2` folder. +- Copy the files `[scene_id]_vh_clean_2.ply` into the `dataset/scannetv2/test` folder according to the ScanNet v2 test [split](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark). + +- Put the file `scannetv2-labels.combined.tsv` in the `dataset/scannetv2` folder. The dataset files are organized as follows. + ``` PointGroup ├── dataset @@ -96,11 +125,12 @@ PointGroup │ │ ├── val │ │ │ ├── [scene_id]_vh_clean_2.ply & [scene_id]_vh_clean_2.labels.ply & [scene_id]_vh_clean_2.0.010000.segs.json & [scene_id].aggregation.json │ │ ├── test -│ │ │ ├── [scene_id]_vh_clean_2.ply +│ │ │ ├── [scene_id]_vh_clean_2.ply │ │ ├── scannetv2-labels.combined.tsv ``` (3) Generate input files `[scene_id]_inst_nostuff.pth` for instance segmentation. + ``` cd dataset/scannetv2 python prepare_data_inst.py --data_split train @@ -109,10 +139,13 @@ python prepare_data_inst.py --data_split test ``` ## Training + ``` -CUDA_VISIBLE_DEVICES=0 python train.py --config config/pointgroup_run1_scannet.yaml +CUDA_VISIBLE_DEVICES=0 python train.py --config config/pointgroup_run1_scannet.yaml ``` + You can start a tensorboard session by + ``` tensorboard --logdir=./exp --port=6666 ``` @@ -120,51 +153,64 @@ tensorboard --logdir=./exp --port=6666 ## Inference and Evaluation (1) If you want to evaluate on validation set, prepare the `.txt` instance ground-truth files as the following. + ``` cd dataset/scannetv2 python prepare_data_inst_gttxt.py ``` -Make sure that you have prepared the `[scene_id]_inst_nostuff.pth` files before. -(2) Test and evaluate. 
+Make sure that you have prepared the `[scene_id]_inst_nostuff.pth` files before. + +(2) Test and evaluate. + +a. To evaluate on validation set, set `split` and `eval` in the config file as `val` and `True`. Then run -a. To evaluate on validation set, set `split` and `eval` in the config file as `val` and `True`. Then run ``` CUDA_VISIBLE_DEVICES=0 python test.py --config config/pointgroup_run1_scannet.yaml ``` + An alternative evaluation method is to set `save_instance` as `True`, and evaluate with the ScanNet official [evaluation script](https://github.com/ScanNet/ScanNet/blob/master/BenchmarkScripts/3d_evaluation/evaluate_semantic_instance.py). b. To run on test set, set (`split`, `eval`, `save_instance`) as (`test`, `False`, `True`). Then run + ``` CUDA_VISIBLE_DEVICES=0 python test.py --config config/pointgroup_run1_scannet.yaml ``` c. To test with a pretrained model, run + ``` CUDA_VISIBLE_DEVICES=0 python test.py --config config/pointgroup_default_scannet.yaml --pretrain $PATH_TO_PRETRAIN_MODEL$ ``` ## Pretrained Model -We provide a pretrained model trained on ScanNet v2 dataset. Download it [here](https://drive.google.com/file/d/1wGolvj73i-vNtvsHhg_KXonNH2eB_6-w/view?usp=sharing). Its performance on ScanNet v2 validation set is 35.2/57.1/71.4 in terms of mAP/mAP50/mAP25. +We provide a pretrained model trained on ScanNet v2 dataset. Download it [here](https://drive.google.com/file/d/1wGolvj73i-vNtvsHhg_KXonNH2eB_6-w/view?usp=sharing). Its performance on ScanNet v2 validation set is 35.2/57.1/71.4 in terms of mAP/mAP50/mAP25. ## Visualize + To visualize the point cloud, you should first install [mayavi](https://docs.enthought.com/mayavi/mayavi/installation.html). 
Then you could visualize by running + ``` -cd util +cd util python visualize.py --data_root $DATA_ROOT$ --result_root $RESULT_ROOT$ --room_name $ROOM_NAME$ --room_split $ROOM_SPLIT$ --task $TASK$ ``` + The visualization task could be `input`, `instance_gt`, `instance_pred`, `semantic_pred` and `semantic_gt`. -## Results on ScanNet Benchmark +## Results on ScanNet Benchmark + Quantitative results on ScanNet test set at the submission time. ![scannet_result](https://github.com/llijiang/PointGroup/blob/master/doc/scannet_benchmark.png) ## TODO List + - [ ] Distributed multi-GPU training ## Citation + If you find this work useful in your research, please cite: + ``` @article{jiang2020pointgroup, title={PointGroup: Dual-Set Point Grouping for 3D Instance Segmentation}, @@ -175,9 +221,9 @@ If you find this work useful in your research, please cite: ``` ## Acknowledgement -This repo is built upon several repos, e.g., [SparseConvNet](https://github.com/facebookresearch/SparseConvNet), [spconv](https://github.com/traveller59/spconv) and [ScanNet](https://github.com/ScanNet/ScanNet). -## Contact -If you have any questions or suggestions about this repo, please feel free to contact me (lijiang@cse.cuhk.edu.hk). +This repo is built upon several repos, e.g., [SparseConvNet](https://github.com/facebookresearch/SparseConvNet), [spconv](https://github.com/traveller59/spconv) and [ScanNet](https://github.com/ScanNet/ScanNet). +## Contact +If you have any questions or suggestions about this repo, please feel free to contact me (lijiang@cse.cuhk.edu.hk). 
diff --git a/download_scannetv2.py b/download_scannetv2.py new file mode 100644 index 0000000..ade540c --- /dev/null +++ b/download_scannetv2.py @@ -0,0 +1,392 @@ +#!/usr/bin/env python +# Downloads ScanNet public data release +# Run with ./download-scannet.py (or python download-scannet.py on Windows) +# -*- coding: utf-8 -*- +import argparse +import os +import urllib.request +import tempfile + +import ssl + +ssl._create_default_https_context = ssl._create_unverified_context + +BASE_URL = "http://kaldir.vc.in.tum.de/scannet/" +TOS_URL = BASE_URL + "ScanNet_TOS.pdf" +FILETYPES = [ + ".aggregation.json", + ".sens", + ".txt", + "_vh_clean.ply", + "_vh_clean_2.0.010000.segs.json", + "_vh_clean_2.ply", + "_vh_clean.segs.json", + "_vh_clean.aggregation.json", + "_vh_clean_2.labels.ply", + "_2d-instance.zip", + "_2d-instance-filt.zip", + "_2d-label.zip", + "_2d-label-filt.zip", +] +FILETYPES_TEST = [".sens", ".txt", "_vh_clean.ply", "_vh_clean_2.ply"] +PREPROCESSED_FRAMES_FILE = ["scannet_frames_25k.zip", "5.6GB"] +TEST_FRAMES_FILE = ["scannet_frames_test.zip", "610MB"] +LABEL_MAP_FILES = ["scannetv2-labels.combined.tsv", "scannet-labels.combined.tsv"] +DATA_EFFICIENT_FILES = [ + "limited-reconstruction-scenes.zip", + "limited-annotation-points.zip", + "limited-bboxes.zip", + "1.7MB", +] +GRIT_FILES = ["ScanNet-GRIT.zip"] +RELEASES = ["v2/scans", "v1/scans"] +RELEASES_TASKS = ["v2/tasks", "v1/tasks"] +RELEASES_NAMES = ["v2", "v1"] +RELEASE = RELEASES[0] +RELEASE_TASKS = RELEASES_TASKS[0] +RELEASE_NAME = RELEASES_NAMES[0] +LABEL_MAP_FILE = LABEL_MAP_FILES[0] +RELEASE_SIZE = "1.2TB" +V1_IDX = 1 + + +def get_release_scans(release_file): + scan_lines = urllib.request.urlopen(release_file) + scans = [] + for scan_line in scan_lines: + scan_id = scan_line.decode("utf8").rstrip("\n") + scans.append(scan_id) + return scans + + +def download_release(release_scans, out_dir, file_types, use_v1_sens, skip_existing): + if len(release_scans) == 0: + return + print("Downloading ScanNet 
" + RELEASE_NAME + " release to " + out_dir + "...") + for scan_id in release_scans: + scan_out_dir = os.path.join(out_dir, scan_id) + download_scan(scan_id, scan_out_dir, file_types, use_v1_sens, skip_existing) + print("Downloaded ScanNet " + RELEASE_NAME + " release.") + + +def download_file(url, out_file): + out_dir = os.path.dirname(out_file) + if not os.path.isdir(out_dir): + os.makedirs(out_dir) + if not os.path.isfile(out_file): + print("\t" + url + " > " + out_file) + fh, out_file_tmp = tempfile.mkstemp(dir=out_dir) + f = os.fdopen(fh, "w") + f.close() + urllib.request.urlretrieve(url, out_file_tmp) + os.rename(out_file_tmp, out_file) + else: + print("WARNING: skipping download of existing file " + out_file) + + +def download_scan(scan_id, out_dir, file_types, use_v1_sens, skip_existing=False): + print("Downloading ScanNet " + RELEASE_NAME + " scan " + scan_id + " ...") + if not os.path.isdir(out_dir): + os.makedirs(out_dir) + for ft in file_types: + v1_sens = use_v1_sens and ft == ".sens" + url = ( + BASE_URL + RELEASE + "/" + scan_id + "/" + scan_id + ft + if not v1_sens + else BASE_URL + RELEASES[V1_IDX] + "/" + scan_id + "/" + scan_id + ft + ) + out_file = out_dir + "/" + scan_id + ft + if skip_existing and os.path.isfile(out_file): + continue + download_file(url, out_file) + print("Downloaded scan " + scan_id) + + +def download_task_data(out_dir): + print("Downloading ScanNet v1 task data...") + files = [ + LABEL_MAP_FILES[V1_IDX], + "obj_classification/data.zip", + "obj_classification/trained_models.zip", + "voxel_labeling/data.zip", + "voxel_labeling/trained_models.zip", + ] + for file in files: + url = BASE_URL + RELEASES_TASKS[V1_IDX] + "/" + file + localpath = os.path.join(out_dir, file) + localdir = os.path.dirname(localpath) + if not os.path.isdir(localdir): + os.makedirs(localdir) + download_file(url, localpath) + print("Downloaded task data.") + + +def download_tfrecords(in_dir, out_dir): + print("Downloading tf records (302 GB)...") + if not 
os.path.exists(out_dir): + os.makedirs(out_dir) + split_to_num_shards = {"train": 100, "val": 25, "test": 10} + + for folder_name in ["hires_tfrecords", "lores_tfrecords"]: + folder_dir = "%s/%s" % (in_dir, folder_name) + save_dir = "%s/%s" % (out_dir, folder_name) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + for split, num_shards in split_to_num_shards.items(): + for i in range(num_shards): + file_name = "%s-%05d-of-%05d.tfrecords" % (split, i, num_shards) + url = "%s/%s" % (folder_dir, file_name) + localpath = "%s/%s/%s" % (out_dir, folder_name, file_name) + download_file(url, localpath) + + +def download_label_map(out_dir): + print("Downloading ScanNet " + RELEASE_NAME + " label mapping file...") + files = [LABEL_MAP_FILE] + for file in files: + url = BASE_URL + RELEASE_TASKS + "/" + file + localpath = os.path.join(out_dir, file) + localdir = os.path.dirname(localpath) + if not os.path.isdir(localdir): + os.makedirs(localdir) + download_file(url, localpath) + print("Downloaded ScanNet " + RELEASE_NAME + " label mapping file.") + + +def main(): + parser = argparse.ArgumentParser( + description="Downloads ScanNet public data release." 
+ ) + parser.add_argument( + "-o", "--out_dir", required=True, help="directory in which to download" + ) + parser.add_argument( + "--task_data", action="store_true", help="download task data (v1)" + ) + parser.add_argument( + "--label_map", action="store_true", help="download label map file" + ) + parser.add_argument( + "--v1", action="store_true", help="download ScanNet v1 instead of v2" + ) + parser.add_argument("--id", help="specific scan id to download") + parser.add_argument( + "--preprocessed_frames", + action="store_true", + help="download preprocessed subset of ScanNet frames (" + + PREPROCESSED_FRAMES_FILE[1] + + ")", + ) + parser.add_argument( + "--test_frames_2d", + action="store_true", + help="download 2D test frames (" + + TEST_FRAMES_FILE[1] + + "; also included with whole dataset download)", + ) + parser.add_argument( + "--data_efficient", + action="store_true", + help="download data efficient task files; also included with whole dataset download)", + ) + parser.add_argument( + "--tf_semantic", + action="store_true", + help="download google tensorflow records for 3D segmentation / detection", + ) + parser.add_argument( + "--grit", + action="store_true", + help="download ScanNet files for General Robust Image Task", + ) + parser.add_argument( + "--type", + help="specific file type to download (.aggregation.json, .sens, .txt, _vh_clean.ply, _vh_clean_2.0.010000.segs.json, _vh_clean_2.ply, _vh_clean.segs.json, _vh_clean.aggregation.json, _vh_clean_2.labels.ply, _2d-instance.zip, _2d-instance-filt.zip, _2d-label.zip, _2d-label-filt.zip)", + ) + parser.add_argument( + "--skip_existing", + action="store_true", + help="skip download of existing files when downloading full release", + ) + args = parser.parse_args() + + print( + "By pressing any key to continue you confirm that you have agreed to the ScanNet terms of use as described at:" + ) + print(TOS_URL) + print("***") + print("Press any key to continue, or CTRL-C to exit.") + key = input("") + + if 
args.v1: + global RELEASE + global RELEASE_TASKS + global RELEASE_NAME + global LABEL_MAP_FILE + RELEASE = RELEASES[V1_IDX] + RELEASE_TASKS = RELEASES_TASKS[V1_IDX] + RELEASE_NAME = RELEASES_NAMES[V1_IDX] + LABEL_MAP_FILE = LABEL_MAP_FILES[V1_IDX] + assert (not args.tf_semantic) and ( + not args.grit + ), "Task files specified invalid for v1" + + release_file = BASE_URL + RELEASE + ".txt" + release_scans = get_release_scans(release_file) + file_types = FILETYPES + release_test_file = BASE_URL + RELEASE + "_test.txt" + release_test_scans = get_release_scans(release_test_file) + file_types_test = FILETYPES_TEST + out_dir_scans = os.path.join(args.out_dir, "scans") + out_dir_test_scans = os.path.join(args.out_dir, "scans_test") + out_dir_tasks = os.path.join(args.out_dir, "tasks") + + if args.type: # download file type + file_type = args.type + if file_type not in FILETYPES: + print("ERROR: Invalid file type: " + file_type) + return + file_types = [file_type] + if file_type in FILETYPES_TEST: + file_types_test = [file_type] + else: + file_types_test = [] + if args.task_data: # download task data + download_task_data(out_dir_tasks) + elif args.label_map: # download label map file + download_label_map(args.out_dir) + elif args.preprocessed_frames: # download preprocessed scannet_frames_25k.zip file + if args.v1: + print("ERROR: Preprocessed frames only available for ScanNet v2") + print( + "You are downloading the preprocessed subset of frames " + + PREPROCESSED_FRAMES_FILE[0] + + " which requires " + + PREPROCESSED_FRAMES_FILE[1] + + " of space." 
+ ) + download_file( + os.path.join(BASE_URL, RELEASE_TASKS, PREPROCESSED_FRAMES_FILE[0]), + os.path.join(out_dir_tasks, PREPROCESSED_FRAMES_FILE[0]), + ) + elif args.test_frames_2d: # download test scannet_frames_test.zip file + if args.v1: + print("ERROR: 2D test frames only available for ScanNet v2") + print( + "You are downloading the 2D test set " + + TEST_FRAMES_FILE[0] + + " which requires " + + TEST_FRAMES_FILE[1] + + " of space." + ) + download_file( + os.path.join(BASE_URL, RELEASE_TASKS, TEST_FRAMES_FILE[0]), + os.path.join(out_dir_tasks, TEST_FRAMES_FILE[0]), + ) + elif args.data_efficient: # download data efficient task files + print( + "You are downloading the data efficient task files" + + " which requires " + + DATA_EFFICIENT_FILES[-1] + + " of space." + ) + for k in range(len(DATA_EFFICIENT_FILES) - 1): + download_file( + os.path.join(BASE_URL, RELEASE_TASKS, DATA_EFFICIENT_FILES[k]), + os.path.join(out_dir_tasks, DATA_EFFICIENT_FILES[k]), + ) + elif args.tf_semantic: # download google tf records + download_tfrecords( + os.path.join(BASE_URL, RELEASE_TASKS, "tf3d"), + os.path.join(out_dir_tasks, "tf3d"), + ) + elif args.grit: # download GRIT file + download_file( + os.path.join(BASE_URL, RELEASE_TASKS, GRIT_FILES[0]), + os.path.join(out_dir_tasks, GRIT_FILES[0]), + ) + elif args.id: # download single scan + scan_id = args.id + is_test_scan = scan_id in release_test_scans + if scan_id not in release_scans and (not is_test_scan or args.v1): + print("ERROR: Invalid scan id: " + scan_id) + else: + out_dir = ( + os.path.join(out_dir_scans, scan_id) + if not is_test_scan + else os.path.join(out_dir_test_scans, scan_id) + ) + scan_file_types = file_types if not is_test_scan else file_types_test + use_v1_sens = not is_test_scan + if not is_test_scan and not args.v1 and ".sens" in scan_file_types: + print( + "Note: ScanNet v2 uses the same .sens files as ScanNet v1: Press 'n' to exclude downloading .sens files for each scan" + ) + key = input("") + if 
key.strip().lower() == "n": + scan_file_types.remove(".sens") + download_scan( + scan_id, + out_dir, + scan_file_types, + use_v1_sens, + skip_existing=args.skip_existing, + ) + else: # download entire release + if len(file_types) == len(FILETYPES): + print( + "WARNING: You are downloading the entire ScanNet " + + RELEASE_NAME + + " release which requires " + + RELEASE_SIZE + + " of space." + ) + else: + print( + "WARNING: You are downloading all ScanNet " + + RELEASE_NAME + + " scans of type " + + file_types[0] + ) + print( + "Note that existing scan directories will be skipped. Delete partially downloaded directories to re-download." + ) + print("***") + print("Press any key to continue, or CTRL-C to exit.") + key = input("") + if not args.v1 and ".sens" in file_types: + print( + "Note: ScanNet v2 uses the same .sens files as ScanNet v1: Press 'n' to exclude downloading .sens files for each scan" + ) + key = input("") + if key.strip().lower() == "n": + file_types.remove(".sens") + download_release( + release_scans, + out_dir_scans, + file_types, + use_v1_sens=True, + skip_existing=args.skip_existing, + ) + if not args.v1: + download_label_map(args.out_dir) + download_release( + release_test_scans, + out_dir_test_scans, + file_types_test, + use_v1_sens=False, + skip_existing=args.skip_existing, + ) + download_file( + os.path.join(BASE_URL, RELEASE_TASKS, TEST_FRAMES_FILE[0]), + os.path.join(out_dir_tasks, TEST_FRAMES_FILE[0]), + ) + for k in range(len(DATA_EFFICIENT_FILES) - 1): + download_file( + os.path.join(BASE_URL, RELEASE_TASKS, DATA_EFFICIENT_FILES[k]), + os.path.join(out_dir_tasks, DATA_EFFICIENT_FILES[k]), + ) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index ef6883e..3f96467 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -torch==1.1 +torch==1.7.1 cmake>=3.13.2 plyfile tensorboardX -pyyaml +pyyaml==5.4.1 scipy diff --git a/spconv-1.0-cp39-cp39-linux_x86_64.whl 
b/spconv-1.0-cp39-cp39-linux_x86_64.whl new file mode 100644 index 0000000..b4bab0f Binary files /dev/null and b/spconv-1.0-cp39-cp39-linux_x86_64.whl differ diff --git a/util/config.py b/util/config.py index ca2b7c6..908e21f 100644 --- a/util/config.py +++ b/util/config.py @@ -1,22 +1,34 @@ -''' +""" config.py Written by Li Jiang -''' +""" import argparse import yaml import os + def get_parser(): - parser = argparse.ArgumentParser(description='Point Cloud Segmentation') - parser.add_argument('--config', type=str, default='config/pointgroup_default_scannet.yaml', help='path to config file') + parser = argparse.ArgumentParser(description="Point Cloud Segmentation") + parser.add_argument( + "--config", + type=str, + default="config/pointgroup_default_scannet.yaml", + help="path to config file", + ) ### pretrain - parser.add_argument('--pretrain', type=str, default='', help='path to pretrain model') + parser.add_argument( + "--pretrain", + type=str, + default="/home/sohaib/Downloads/pointgroup.pth", + help="path to pretrain model", + ) args_cfg = parser.parse_args() assert args_cfg.config is not None - with open(args_cfg.config, 'r') as f: + with open(args_cfg.config, "r") as f: + print("Config file: ", args_cfg.config) config = yaml.load(f) for key in config: for k, v in config[key].items(): @@ -26,4 +38,8 @@ def get_parser(): cfg = get_parser() -setattr(cfg, 'exp_path', os.path.join('exp', cfg.dataset, cfg.model_name, cfg.config.split('/')[-1][:-5])) +setattr( + cfg, + "exp_path", + os.path.join("exp", cfg.dataset, cfg.model_name, cfg.config.split("/")[-1][:-5]), +)