Skip to content

Commit

Permalink
initial release
Browse files Browse the repository at this point in the history
  • Loading branch information
wangg12 committed May 8, 2021
1 parent d272cb8 commit f35b0d5
Show file tree
Hide file tree
Showing 197 changed files with 152,067 additions and 4 deletions.
50 changes: 50 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
pretrained_models
.DS_Store
# output dir
/data
/datasets
*.so.*
*.tar.gz
*.egg-info*
/output
instant_test_output
inference_test_output


*.ttf
*.jpg
*.png
*.txt

# compilation and distribution
__pycache__
_ext
*.pyc
*.so
detectron2.egg-info/
build/
dist/

# pytorch/python/numpy formats
*.pth
*.pkl
*.npy

# ipython/jupyter notebooks
*.ipynb
**/.ipynb_checkpoints/

# Editor temporaries
*.swn
*.swo
*.swp
*~

# Pycharm editor settings
.idea

# VSCode editor settings
.vscode

# project dirs
/models
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2020- Gu Wang
Copyright 2020- Gu Wang

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
70 changes: 67 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,76 @@
This repo provides the PyTorch implementation of the work:

**Gu Wang, Fabian Manhardt, Federico Tombari, Xiangyang Ji. GDR-Net: Geometry-Guided Direct Regression Network for Monocular 6D Object Pose Estimation. In CVPR 2021.**
[[ArXiv]](http://arxiv.org/abs/2102.12145)
[[ArXiv](http://arxiv.org/abs/2102.12145)][[bibtex](#Citation)]

Code will be coming soon.


## Citing
## Overview
<p align="center">
<img src='assets/gdrn_architecture.png' width='800'>
<p>



## Requirements
* Ubuntu 16.04/18.04, CUDA 10.1/10.2, python >= 3.6, PyTorch >= 1.6, torchvision
* Install `detectron2` from [source](https://github.com/facebookresearch/detectron2)
* `sh scripts/install_deps.sh`
* Compile the cpp extension for `farthest points sampling (fps)`:
```
sh core/csrc/compile.sh
```
## Datasets
Download the 6D pose datasets (LM, LM-O, YCB-V) from the
[BOP website](https://bop.felk.cvut.cz/datasets/) and
[VOC 2012](https://pjreddie.com/projects/pascal-voc-dataset-mirror/)
for background images.
Please also download the `image_sets` and `test_bboxes` from
here ([BaiduNetDisk](https://pan.baidu.com/s/1gGoZGkuMYxhU9LBKxuSz0g), password: qjfk).
The structure of `datasets` folder should look like below:
```
# recommend using soft links (ln -sf)
datasets/
├── BOP_DATASETS
├──lm
├──lmo
├──ycbv
├── lm_imgn # the OpenGL rendered images for LM, 1k/obj
├── lm_renders_blender # the Blender rendered images for LM, 10k/obj (pvnet-rendering)
├── VOCdevkit
```
* `lm_imgn` comes from [DeepIM](https://github.com/liyi14/mx-DeepIM), which can be downloaded here ([BaiduNetDisk](https://pan.baidu.com/s/1e9SJoqb0EmyqVLEVlbNQIA), password: vr0i).
* `lm_renders_blender` comes from [pvnet-rendering](https://github.com/zju3dv/pvnet-rendering), note that we do not need the fused data.
## Training GDR-Net
`./core/gdrn_modeling/train_gdrn.sh <config_path> <gpu_ids> (other args)`
Example:
```
./core/gdrn_modeling/train_gdrn.sh configs/gdrn/lm/a6_cPnP_lm13.py 0 # multiple gpus: 0,1,2,3
# add --resume if you want to resume from an interrupted experiment.
```
Our trained GDR-Net models can be found here ([BaiduNetDisk](https://pan.baidu.com/s/1_MEZJBd67hdxcE8JzmnOtA), password: kedv). <br />
<sub><sup>(Note that the models for BOP setup in the supplement were trained using a refactored version of this repo (not compatible), they are slightly better than the models provided here.)</sup></sub>
## Evaluation
`./core/gdrn_modeling/test_gdrn.sh <config_path> <gpu_ids> <ckpt_path> (other args)`
Example:
```
./core/gdrn_modeling/test_gdrn.sh configs/gdrn/lmo/a6_cPnP_AugAAETrunc_BG0.5_lmo_real_pbr0.1_40e.py 0 output/gdrn/lmo/a6_cPnP_AugAAETrunc_BG0.5_lmo_real_pbr0.1_40e/gdrn_lmo_real_pbr.pth
```
## Citation
If you find this useful in your research, please consider citing:
```
@InProceedings{Wang_2021_GDRN,
Expand Down
Binary file added assets/gdrn_architecture.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
174 changes: 174 additions & 0 deletions configs/_base_/common_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
OUTPUT_ROOT = "output"
# if OUTPUT_DIR="auto", osp.join(cfg.OUTPUT_ROOT, osp.splitext(args.config_file)[0].split("configs/")[1])
OUTPUT_DIR = "output"

EXP_NAME = ""

DEBUG = False
# Set seed to negative to fully randomize everything.
# Set seed to positive to use a fixed seed. Note that a fixed seed does not
# guarantee fully deterministic behavior.
SEED = -1
# Benchmark different cudnn algorithms.
# If input images have very different sizes, this option will have large overhead
# for about 10k iterations. It usually hurts total time, but can benefit for certain models.
# If input images have the same or similar sizes, benchmark is often helpful.
CUDNN_BENCHMARK = True
# The period (in terms of steps) for minibatch visualization at train time.
# Set to 0 to disable.
VIS_PERIOD = 0

# -----------------------------------------------------------------------------
# Input
# -----------------------------------------------------------------------------
INPUT = dict(
# Whether the model needs RGB, YUV, HSV etc.
FORMAT="BGR",
MIN_SIZE_TRAIN=(480,),
MAX_SIZE_TRAIN=640,
MIN_SIZE_TRAIN_SAMPLING="choice",
MIN_SIZE_TEST=480,
MAX_SIZE_TEST=640,
WITH_DEPTH=False,
AUG_DEPTH=False,
# color aug
COLOR_AUG_PROB=0.0,
COLOR_AUG_TYPE="ROI10D",
COLOR_AUG_CODE="",
COLOR_AUG_SYN_ONLY=False,
## bg images
BG_TYPE="VOC_table", # VOC_table | coco | VOC | SUN2012
BG_IMGS_ROOT="datasets/VOCdevkit/VOC2012/", # "datasets/coco/train2017/"
NUM_BG_IMGS=10000,
CHANGE_BG_PROB=0.5, # prob to change bg of real image
# truncation fg (randomly replace some side of fg with bg during replace_bg)
TRUNCATE_FG=False,
BG_KEEP_ASPECT_RATIO=True,
## bbox aug
DZI_TYPE="uniform", # uniform, truncnorm, none, roi10d
DZI_PAD_SCALE=1.0,
DZI_SCALE_RATIO=0.25, # wh scale
DZI_SHIFT_RATIO=0.25, # center shift
# smooth xyz map by median filter
SMOOTH_XYZ=False,
)

# -----------------------------------------------------------------------------
# Datasets
# -------------------------------------------------------------------------
DATASETS = dict(
TRAIN=(),
TRAIN2=(), # the second training dataset, useful for data balancing
TRAIN2_RATIO=0.0,
# List of the pre-computed proposal files for training, which must be consistent
# with datasets listed in DATASETS.TRAIN.
PROPOSAL_FILES_TRAIN=(),
# Number of top scoring precomputed proposals to keep for training
PRECOMPUTED_PROPOSAL_TOPK_TRAIN=2000,
TEST=(),
PROPOSAL_FILES_TEST=(),
# Number of top scoring precomputed proposals to keep for test
PRECOMPUTED_PROPOSAL_TOPK_TEST=1000,
DET_FILES_TEST=(),
DET_TOPK_PER_OBJ=1,
DET_THR=0.0, # filter detections
# NOTE: override if symmetric objects are different, used for custom evaluator
# SYM_OBJS=["024_bowl", "036_wood_block", "051_large_clamp", "052_extra_large_clamp", "061_foam_brick"], # ycbv
# SYM_OBJS=["002_master_chef_can", "024_bowl", "025_mug", "036_wood_block", "040_large_marker", "051_large_clamp",
# "052_extra_large_clamp", "061_foam_brick"], # ycbv_bop
SYM_OBJS=["bowl", "cup", "eggbox", "glue"],
)

# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
DATALOADER = dict(
# Number of data loading threads
NUM_WORKERS=4,
ASPECT_RATIO_GROUPING=False, # default True in detectron2
# Default sampler for dataloader
# Options: TrainingSampler, RepeatFactorTrainingSampler
SAMPLER_TRAIN="TrainingSampler",
# Repeat threshold for RepeatFactorTrainingSampler
REPEAT_THRESHOLD=0.0,
# If True, the dataloader will filter out images that have no associated
# annotations at train time.
FILTER_EMPTY_ANNOTATIONS=True,
# NOTE: set to False if you want to see the image anyways
FILTER_EMPTY_DETS=True, # filter images with empty detections
# filter out instances with visib_fract <= visib_thr at train time
FILTER_VISIB_THR=0.0,
)

# ---------------------------------------------------------------------------- #
# Solver
# ---------------------------------------------------------------------------- #
SOLVER = dict(
IMS_PER_BATCH=6,
TOTAL_EPOCHS=160,
# NOTE: use string code to get cfg dict like mmdet
# will ignore OPTIMIZER_NAME, BASE_LR, MOMENTUM, WEIGHT_DECAY
OPTIMIZER_CFG=dict(type="RMSprop", lr=1e-4, momentum=0.0, weight_decay=0),
#######
GAMMA=0.1,
BIAS_LR_FACTOR=1.0,
LR_SCHEDULER_NAME="WarmupMultiStepLR", # WarmupMultiStepLR | flat_and_anneal
WARMUP_METHOD="linear",
WARMUP_FACTOR=1.0 / 1000,
WARMUP_ITERS=1000,
ANNEAL_METHOD="step",
ANNEAL_POINT=0.75,
POLY_POWER=0.9, # poly power
REL_STEPS=(0.5, 0.75),
# checkpoint
CHECKPOINT_PERIOD=5,
CHECKPOINT_BY_EPOCH=True,
MAX_TO_KEEP=5,
# Enable automatic mixed precision for training
# Note that this does not change model's inference behavior.
# To use AMP in inference, run inference under autocast()
AMP=dict(ENABLED=False),
)

# ---------------------------------------------------------------------------- #
# Specific train options
# ---------------------------------------------------------------------------- #
TRAIN = dict(
PRINT_FREQ=100,
VERBOSE=False,
VIS=False,
# vis imgs in tensorboard
VIS_IMG=False,
)
# ---------------------------------------------------------------------------- #
# Specific val options
# ---------------------------------------------------------------------------- #
VAL = dict(
DATASET_NAME="lm",
SCRIPT_PATH="lib/pysixd/scripts/eval_pose_results_more.py",
RESULTS_PATH="",
TARGETS_FILENAME="lm_test_targets_bb8.json",
ERROR_TYPES="ad,rete,re,te,proj",
RENDERER_TYPE="cpp", # cpp, python, egl, aae
SPLIT="test",
SPLIT_TYPE="bb8",
N_TOP=1, # SISO: 1, VIVO: -1 (for LINEMOD, 1/-1 are the same)
EVAL_CACHED=False, # if the predicted poses have been saved
SCORE_ONLY=False, # if the errors have been calculated
EVAL_PRINT_ONLY=False, # if the scores/recalls have been saved
EVAL_PRECISION=False, # use precision or recall
USE_BOP=False, # whether to use bop toolkit
)

# ---------------------------------------------------------------------------- #
# Specific test options
# ---------------------------------------------------------------------------- #
TEST = dict(
EVAL_PERIOD=0,
VIS=False,
TEST_BBOX_TYPE="gt", # gt | est
# USE_PNP = False, # use pnp or direct prediction
# PNP_TYPE = "ransac_pnp",
PRECISE_BN=dict(ENABLED=False, NUM_ITER=200),
AMP_TEST=False,
)
Loading

0 comments on commit f35b0d5

Please sign in to comment.