From 0b75086bbc2eca710b5424cba3df47878995bd2c Mon Sep 17 00:00:00 2001 From: Hang Zhang Date: Wed, 27 Jan 2021 18:56:56 -0800 Subject: [PATCH] Add detectron2 wrapper (#132) * add detectron2 wrapper --- README.md | 2 +- d2/README.md | 213 +++++ d2/configs/Base-RCNN-FPN.yaml | 42 + ..._rcnn_R_101_FPN_syncbn_range-scale_1x.yaml | 30 + ...e_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml | 30 + ...ResNeSt_101_FPN_syncbn_range-scale_1x.yaml | 34 + ...ResNeSt_200_FPN_syncbn_range-scale_1x.yaml | 34 + ..._ResNeSt_50_FPN_syncbn_range-scale-1x.yaml | 34 + ..._rcnn_R_101_FPN_syncbn_range-scale_1x.yaml | 30 + ...r_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml | 25 + ...ResNeSt_101_FPN_syncbn_range-scale_1x.yaml | 29 + ...NeSt_50_FPN_dcn_syncbn_range-scale_1x.yaml | 38 + ..._ResNeSt_50_FPN_syncbn_range-scale_1x.yaml | 35 + ...mask_cascade_rcnn_R_101_FPN_syncbn_1x.yaml | 23 + .../mask_cascade_rcnn_R_50_FPN_syncbn_1x.yaml | 23 + ...ascade_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml | 35 + ...NeSt_200_FPN_dcn_syncBN_all_tricks_3x.yaml | 46 ++ ..._ResNeSt_200_FPN_syncBN_all_tricks_3x.yaml | 47 ++ ...cascade_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml | 37 + .../mask_rcnn_R_101_FPN_syncbn_1x.yaml | 19 + .../mask_rcnn_R_50_FPN_syncbn_1x.yaml | 19 + .../mask_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml | 28 + .../mask_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml | 28 + .../ResNeSt-Base-Panoptic-FPN.yaml | 9 + ...ptic_ResNeSt_200_FPN_syncBN_tricks_3x.yaml | 42 + d2/configs/ResNest-Base-RCNN-FPN.yaml | 4 + d2/datasets/prepare_coco.py | 64 ++ d2/train_net.py | 170 ++++ resnest/d2/__init__.py | 2 + resnest/d2/config.py | 20 + resnest/d2/resnest.py | 734 ++++++++++++++++++ resnest/d2/splat.py | 179 +++++ 32 files changed, 2104 insertions(+), 1 deletion(-) create mode 100644 d2/README.md create mode 100644 d2/configs/Base-RCNN-FPN.yaml create mode 100644 d2/configs/COCO-Detection/faster_cascade_rcnn_R_101_FPN_syncbn_range-scale_1x.yaml create mode 100644 d2/configs/COCO-Detection/faster_cascade_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml create mode 100644 d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_101_FPN_syncbn_range-scale_1x.yaml create mode 100644 d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_200_FPN_syncbn_range-scale_1x.yaml create mode 100644 d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_50_FPN_syncbn_range-scale-1x.yaml create mode 100644 d2/configs/COCO-Detection/faster_rcnn_R_101_FPN_syncbn_range-scale_1x.yaml create mode 100644 d2/configs/COCO-Detection/faster_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml create mode 100644 d2/configs/COCO-Detection/faster_rcnn_ResNeSt_101_FPN_syncbn_range-scale_1x.yaml create mode 100644 d2/configs/COCO-Detection/faster_rcnn_ResNeSt_50_FPN_dcn_syncbn_range-scale_1x.yaml create mode 100644 d2/configs/COCO-Detection/faster_rcnn_ResNeSt_50_FPN_syncbn_range-scale_1x.yaml create mode 100644 d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_R_101_FPN_syncbn_1x.yaml create mode 100644 d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_R_50_FPN_syncbn_1x.yaml create mode 100644 d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml create mode 100644 d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_200_FPN_dcn_syncBN_all_tricks_3x.yaml create mode 100644 d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_200_FPN_syncBN_all_tricks_3x.yaml create mode 100644 d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml create mode 100644 d2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_syncbn_1x.yaml create mode 
100644 d2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_syncbn_1x.yaml
 create mode 100644 d2/configs/COCO-InstanceSegmentation/mask_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml
 create mode 100644 d2/configs/COCO-InstanceSegmentation/mask_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml
 create mode 100644 d2/configs/COCO-PanopticSegmentation/ResNeSt-Base-Panoptic-FPN.yaml
 create mode 100644 d2/configs/COCO-PanopticSegmentation/panoptic_ResNeSt_200_FPN_syncBN_tricks_3x.yaml
 create mode 100644 d2/configs/ResNest-Base-RCNN-FPN.yaml
 create mode 100644 d2/datasets/prepare_coco.py
 create mode 100644 d2/train_net.py
 create mode 100644 resnest/d2/__init__.py
 create mode 100644 resnest/d2/config.py
 create mode 100644 resnest/d2/resnest.py
 create mode 100644 resnest/d2/splat.py

diff --git a/README.md b/README.md
index 092e733..8ad7360 100644
--- a/README.md
+++ b/README.md
@@ -351,7 +351,7 @@ python verify.py --model resnest50 --crop-size 224
 For object detection and instance segmentation models, please visit our [detectron2-ResNeSt fork](https://github.com/zhanghang1989/detectron2-ResNeSt).
 
 ### Semantic Segmentation
- 
+
 - Training with PyTorch: [Encoding Toolkit](https://hangzhang.org/PyTorch-Encoding/model_zoo/segmentation.html).
 - Training with MXNet: [GluonCV Toolkit](https://gluon-cv.mxnet.io/model_zoo/segmentation.html#ade20k-dataset).

diff --git a/d2/README.md b/d2/README.md
new file mode 100644
index 0000000..971106e
--- /dev/null
+++ b/d2/README.md
@@ -0,0 +1,213 @@
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/resnest-split-attention-networks/instance-segmentation-on-coco)](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=resnest-split-attention-networks)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/resnest-split-attention-networks/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=resnest-split-attention-networks)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/resnest-split-attention-networks/panoptic-segmentation-on-coco-panoptic)](https://paperswithcode.com/sota/panoptic-segmentation-on-coco-panoptic?p=resnest-split-attention-networks)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/resnest-split-attention-networks/instance-segmentation-on-coco-minival)](https://paperswithcode.com/sota/instance-segmentation-on-coco-minival?p=resnest-split-attention-networks)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/resnest-split-attention-networks/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=resnest-split-attention-networks)
+
+# ResNeSt (Detectron2 Wrapper)
+
+Code for the detection and instance segmentation experiments in [ResNeSt](https://hangzhang.org/files/resnest.pdf).
+
+
+## Training and Inference
+Please follow [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md) to install detectron2.
+
+To train a model with 8 GPUs, run:
+```shell
+python train_net.py --num-gpus 8 --config-file your_config.yaml
+```
+
+For inference, run:
+```shell
+python train_net.py \
+    --config-file your_config.yaml \
+    --eval-only MODEL.WEIGHTS /path/to/checkpoint_file
+```
+
+For the inference demo, please see [GETTING_STARTED.md](https://github.com/facebookresearch/detectron2/blob/master/GETTING_STARTED.md).
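+
+To use the ResNeSt backbone outside of `train_net.py`, the one extra step over a stock detectron2 setup is registering the ResNeSt config keys before loading a config (importing `resnest.d2` also registers the backbone builders). A minimal inference sketch, assuming the checkpoint path and input image below are replaced with real ones:
+```python
+import cv2
+from detectron2.config import get_cfg
+from detectron2.engine import DefaultPredictor
+from resnest.d2 import add_resnest_config  # importing resnest.d2 registers the backbones
+
+cfg = get_cfg()
+add_resnest_config(cfg)  # adds MODEL.RESNETS.RADIX, DEEP_STEM, AVD, AVG_DOWN, ...
+cfg.merge_from_file("configs/COCO-Detection/faster_rcnn_ResNeSt_50_FPN_syncbn_range-scale_1x.yaml")
+cfg.MODEL.WEIGHTS = "/path/to/checkpoint_file"  # placeholder
+
+predictor = DefaultPredictor(cfg)
+outputs = predictor(cv2.imread("input.jpg"))  # outputs["instances"]: boxes, scores, classes
+```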
+
+## Pretrained Models
+
+### Object Detection
+
+| Method | Backbone | mAP% | download |
+|:---|:---|:---:|:---|
+| Faster R-CNN | ResNet-50 | 39.25 | config \| model \| log |
+| | ResNet-101 | 41.37 | config \| model \| log |
+| | ResNeSt-50 (ours) | 42.33 | config \| model \| log |
+| | ResNeSt-50-DCNv2 (ours) | 44.11 | config \| model \| log |
+| | ResNeSt-101 (ours) | 44.72 | config \| model \| log |
+| Cascade R-CNN | ResNet-50 | 42.52 | config \| model \| log |
+| | ResNet-101 | 44.03 | config \| model \| log |
+| | ResNeSt-50 (ours) | 45.41 | config \| model \| log |
+| | ResNeSt-101 (ours) | 47.50 | config \| model \| log |
+| | ResNeSt-200 (ours) | 49.03 | config \| model \| log |
+
+We train all models with FPN, SyncBN, and image scale augmentation (the shorter side of an image is picked randomly from 640 to 800). The 1x learning rate schedule is used. All results are reported on the COCO-2017 validation set.
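+
+The range-scale augmentation above maps onto the following detectron2 config keys; this sketch mirrors what the `COCO-Detection` configs in this patch set:
+```python
+from detectron2.config import get_cfg
+from resnest.d2 import add_resnest_config
+
+cfg = get_cfg()
+add_resnest_config(cfg)
+cfg.INPUT.MIN_SIZE_TRAIN = (640, 800)        # bounds for the shorter side
+cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING = "range"  # sample the shorter side uniformly from [640, 800]
+cfg.INPUT.MAX_SIZE_TRAIN = 1333              # cap on the longer side
+```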
+
+### Instance Segmentation
+
+| Method | Backbone | bbox | mask | download |
+|:---|:---|:---:|:---:|:---|
+| Mask R-CNN | ResNet-50 | 39.97 | 36.05 | config \| model \| log |
+| | ResNet-101 | 41.78 | 37.51 | config \| model \| log |
+| | ResNeSt-50 (ours) | 42.81 | 38.14 | config \| model \| log |
+| | ResNeSt-101 (ours) | 45.75 | 40.65 | config \| model \| log |
+| Cascade R-CNN | ResNet-50 | 43.06 | 37.19 | config \| model \| log |
+| | ResNet-101 | 44.79 | 38.52 | config \| model \| log |
+| | ResNeSt-50 (ours) | 46.19 | 39.55 | config \| model \| log |
+| | ResNeSt-101 (ours) | 48.30 | 41.56 | config \| model \| log |
+| | ResNeSt-200-tricks-3x (ours) | 50.54 | 44.21 | config \| model \| log |
+| | ResNeSt-200-dcn-tricks-3x (ours) | 50.91 | 44.50 | config \| model \| log |
+| | | 53.30* | 47.10* | |
+
+All models are trained with FPN and SyncBN. For data augmentation, the shorter side of each input image is randomly scaled to one of (640, 672, 704, 736, 768, 800). The 1x learning rate schedule is used unless otherwise specified. All results are reported on the COCO-2017 validation set. Values marked with * are multi-scale testing results on test-dev2019.
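+
+The starred numbers use multi-scale testing, which in this wrapper goes through detectron2's test-time augmentation: `train_net.py` wraps the model in `GeneralizedRCNNWithTTA` when `TEST.AUG.ENABLED` is set. A rough sketch of that evaluation path (paths are placeholders; the exact scales behind the starred entries are not specified here):
+```python
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import get_cfg
+from detectron2.modeling import GeneralizedRCNNWithTTA
+from resnest.d2 import add_resnest_config
+from train_net import Trainer
+
+cfg = get_cfg()
+add_resnest_config(cfg)
+cfg.merge_from_file("configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_200_FPN_dcn_syncBN_all_tricks_3x.yaml")
+cfg.MODEL.WEIGHTS = "/path/to/checkpoint_file"  # placeholder
+
+model = Trainer.build_model(cfg)
+DetectionCheckpointer(model).resume_or_load(cfg.MODEL.WEIGHTS, resume=False)
+tta_model = GeneralizedRCNNWithTTA(cfg, model)  # averages predictions over flips/scales
+results = Trainer.test(cfg, tta_model)
+```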
+
+### Panoptic Segmentation
+
+| Backbone | bbox | mask | PQ | download |
+|:---|:---:|:---:|:---:|:---|
+| ResNeSt-200 | 51.00 | 43.68 | 47.90 | config \| model \| log |
+ + +## Reference + +**ResNeSt: Split-Attention Networks** [[arXiv](https://arxiv.org/pdf/2004.08955.pdf)] + +Hang Zhang, Chongruo Wu, Zhongyue Zhang, Yi Zhu, Zhi Zhang, Haibin Lin, Yue Sun, Tong He, Jonas Muller, R. Manmatha, Mu Li and Alex Smola + +``` +@article{zhang2020resnest, +title={ResNeSt: Split-Attention Networks}, +author={Zhang, Hang and Wu, Chongruo and Zhang, Zhongyue and Zhu, Yi and Zhang, Zhi and Lin, Haibin and Sun, Yue and He, Tong and Muller, Jonas and Manmatha, R. and Li, Mu and Smola, Alexander}, +journal={arXiv preprint arXiv:2004.08955}, +year={2020} +} +``` + +### Contributors +[Chongruo Wu](https://github.com/chongruo), [Zhongyue Zhang](http://zhongyuezhang.com/), [Hang Zhang](https://hangzhang.org/) diff --git a/d2/configs/Base-RCNN-FPN.yaml b/d2/configs/Base-RCNN-FPN.yaml new file mode 100644 index 0000000..3e020f2 --- /dev/null +++ b/d2/configs/Base-RCNN-FPN.yaml @@ -0,0 +1,42 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + # Detectron1 uses 2000 proposals per-batch, + # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) + # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. + POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +VERSION: 2 diff --git a/d2/configs/COCO-Detection/faster_cascade_rcnn_R_101_FPN_syncbn_range-scale_1x.yaml b/d2/configs/COCO-Detection/faster_cascade_rcnn_R_101_FPN_syncbn_range-scale_1x.yaml new file mode 100644 index 0000000..659656b --- /dev/null +++ b/d2/configs/COCO-Detection/faster_cascade_rcnn_R_101_FPN_syncbn_range-scale_1x.yaml @@ -0,0 +1,30 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: False + RESNETS: + DEPTH: 101 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + RPN: + POST_NMS_TOPK_TRAIN: 2000 +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 +TEST: + PRECISE_BN: + ENABLED: True + diff --git a/d2/configs/COCO-Detection/faster_cascade_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml b/d2/configs/COCO-Detection/faster_cascade_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml new file mode 100644 index 0000000..8d3c2ad --- /dev/null +++ b/d2/configs/COCO-Detection/faster_cascade_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml @@ -0,0 +1,30 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False 
+ RESNETS: + DEPTH: 50 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + RPN: + POST_NMS_TOPK_TRAIN: 2000 +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 +TEST: + PRECISE_BN: + ENABLED: True + diff --git a/d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_101_FPN_syncbn_range-scale_1x.yaml b/d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_101_FPN_syncbn_range-scale_1x.yaml new file mode 100644 index 0000000..94c9072 --- /dev/null +++ b/d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_101_FPN_syncbn_range-scale_1x.yaml @@ -0,0 +1,34 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest101_detectron-486f69a8.pth" + MASK_ON: False + RESNETS: + DEPTH: 101 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + RPN: + POST_NMS_TOPK_TRAIN: 2000 + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_200_FPN_syncbn_range-scale_1x.yaml b/d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_200_FPN_syncbn_range-scale_1x.yaml new file mode 100644 index 0000000..338f686 --- /dev/null +++ b/d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_200_FPN_syncbn_range-scale_1x.yaml @@ -0,0 +1,34 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest200_detectron-02644020.pth" + MASK_ON: False + RESNETS: + DEPTH: 200 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + RPN: + POST_NMS_TOPK_TRAIN: 2000 + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_50_FPN_syncbn_range-scale-1x.yaml b/d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_50_FPN_syncbn_range-scale-1x.yaml new file mode 100644 index 0000000..c61906f --- /dev/null +++ b/d2/configs/COCO-Detection/faster_cascade_rcnn_ResNeSt_50_FPN_syncbn_range-scale-1x.yaml @@ -0,0 +1,34 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest50_detectron-255b5649.pth" + MASK_ON: False + RESNETS: + DEPTH: 50 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + RPN: + POST_NMS_TOPK_TRAIN: 2000 + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: 
+ MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-Detection/faster_rcnn_R_101_FPN_syncbn_range-scale_1x.yaml b/d2/configs/COCO-Detection/faster_rcnn_R_101_FPN_syncbn_range-scale_1x.yaml new file mode 100644 index 0000000..f55c188 --- /dev/null +++ b/d2/configs/COCO-Detection/faster_rcnn_R_101_FPN_syncbn_range-scale_1x.yaml @@ -0,0 +1,30 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: False + RESNETS: + DEPTH: 101 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 +TEST: + PRECISE_BN: + ENABLED: True + + + + + + diff --git a/d2/configs/COCO-Detection/faster_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml b/d2/configs/COCO-Detection/faster_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml new file mode 100644 index 0000000..7d5b581 --- /dev/null +++ b/d2/configs/COCO-Detection/faster_rcnn_R_50_FPN_syncbn_range-scale_1x.yaml @@ -0,0 +1,25 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: False + RESNETS: + STRIDE_IN_1X1: True + DEPTH: 50 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" +INPUT: + MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 +TEST: + PRECISE_BN: + ENABLED: True + + + diff --git a/d2/configs/COCO-Detection/faster_rcnn_ResNeSt_101_FPN_syncbn_range-scale_1x.yaml b/d2/configs/COCO-Detection/faster_rcnn_ResNeSt_101_FPN_syncbn_range-scale_1x.yaml new file mode 100644 index 0000000..f759107 --- /dev/null +++ b/d2/configs/COCO-Detection/faster_rcnn_ResNeSt_101_FPN_syncbn_range-scale_1x.yaml @@ -0,0 +1,29 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest101_detectron-486f69a8.pth" + MASK_ON: False + RESNETS: + DEPTH: 101 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-Detection/faster_rcnn_ResNeSt_50_FPN_dcn_syncbn_range-scale_1x.yaml b/d2/configs/COCO-Detection/faster_rcnn_ResNeSt_50_FPN_dcn_syncbn_range-scale_1x.yaml new file mode 100644 index 0000000..df53f92 --- /dev/null +++ b/d2/configs/COCO-Detection/faster_rcnn_ResNeSt_50_FPN_dcn_syncbn_range-scale_1x.yaml @@ -0,0 +1,38 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest50_detectron-255b5649.pth" + MASK_ON: False + RESNETS: + DEPTH: 50 + STRIDE_IN_1X1: False + RADIX: 2 + DEFORM_ON_PER_STAGE: [False, True, True, True] # on Res3,Res4,Res5 + DEFORM_MODULATED: True + DEFORM_NUM_GROUPS: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 
+ BASE_LR: 0.02 +INPUT: + MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True + + + + + + diff --git a/d2/configs/COCO-Detection/faster_rcnn_ResNeSt_50_FPN_syncbn_range-scale_1x.yaml b/d2/configs/COCO-Detection/faster_rcnn_ResNeSt_50_FPN_syncbn_range-scale_1x.yaml new file mode 100644 index 0000000..08c48dc --- /dev/null +++ b/d2/configs/COCO-Detection/faster_rcnn_ResNeSt_50_FPN_syncbn_range-scale_1x.yaml @@ -0,0 +1,35 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest50_detectron-255b5649.pth" + MASK_ON: False + RESNETS: + DEPTH: 50 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + MIN_SIZE_TRAIN: (640, 800) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1333 + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True + + + + + + diff --git a/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_R_101_FPN_syncbn_1x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_R_101_FPN_syncbn_1x.yaml new file mode 100644 index 0000000..12390be --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_R_101_FPN_syncbn_1x.yaml @@ -0,0 +1,23 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: True + RESNETS: + DEPTH: 101 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + ROI_MASK_HEAD: + NORM: "SyncBN" + RPN: + POST_NMS_TOPK_TRAIN: 2000 +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_R_50_FPN_syncbn_1x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_R_50_FPN_syncbn_1x.yaml new file mode 100644 index 0000000..142c6f6 --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_R_50_FPN_syncbn_1x.yaml @@ -0,0 +1,23 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + ROI_MASK_HEAD: + NORM: "SyncBN" + RPN: + POST_NMS_TOPK_TRAIN: 2000 +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml new file mode 100644 index 0000000..3656aeb --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml @@ -0,0 +1,35 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest101_detectron-486f69a8.pth" + MASK_ON: True + RESNETS: + DEPTH: 101 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + ROI_MASK_HEAD: + NORM: "SyncBN" + RPN: + POST_NMS_TOPK_TRAIN: 2000 + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 
57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True + + diff --git a/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_200_FPN_dcn_syncBN_all_tricks_3x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_200_FPN_dcn_syncBN_all_tricks_3x.yaml new file mode 100644 index 0000000..1d69288 --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_200_FPN_dcn_syncBN_all_tricks_3x.yaml @@ -0,0 +1,46 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest200_detectron-02644020.pth" + MASK_ON: True + RESNETS: + DEPTH: 200 + STRIDE_IN_1X1: False + RADIX: 2 + DEFORM_ON_PER_STAGE: [False, True, True, True] # on Res3,Res4,Res5 + DEFORM_MODULATED: True + DEFORM_NUM_GROUPS: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + ROI_MASK_HEAD: + NUM_CONV: 8 + NORM: "SyncBN" + RPN: + POST_NMS_TOPK_TRAIN: 2000 + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (240000, 255000) + MAX_ITER: 270000 +INPUT: + MIN_SIZE_TRAIN: (640, 864) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1440 + CROP: + ENABLED: True + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True + AUG: + ENABLED: False diff --git a/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_200_FPN_syncBN_all_tricks_3x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_200_FPN_syncBN_all_tricks_3x.yaml new file mode 100644 index 0000000..26a2eaf --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_200_FPN_syncBN_all_tricks_3x.yaml @@ -0,0 +1,47 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest200_detectron-02644020.pth" + MASK_ON: True + RESNETS: + DEPTH: 200 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + ROI_MASK_HEAD: + NUM_CONV: 8 + NORM: "SyncBN" + RPN: + POST_NMS_TOPK_TRAIN: 2000 + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (240000, 255000) + MAX_ITER: 270000 +INPUT: + MIN_SIZE_TRAIN: (640, 864) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1440 + CROP: + ENABLED: True + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True + + + + + + diff --git a/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml new file mode 100644 index 0000000..d1423f6 --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_cascade_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml @@ -0,0 +1,37 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest50_detectron-255b5649.pth" + MASK_ON: True + RESNETS: + DEPTH: 50 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + ROI_MASK_HEAD: + NORM: "SyncBN" + RPN: + 
POST_NMS_TOPK_TRAIN: 2000 + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True + + + + diff --git a/d2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_syncbn_1x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_syncbn_1x.yaml new file mode 100644 index 0000000..2a3cff4 --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_syncbn_1x.yaml @@ -0,0 +1,19 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" + MASK_ON: True + RESNETS: + DEPTH: 101 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + ROI_MASK_HEAD: + NORM: "SyncBN" +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_syncbn_1x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_syncbn_1x.yaml new file mode 100644 index 0000000..314bcbb --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_syncbn_1x.yaml @@ -0,0 +1,19 @@ +_BASE_: "../Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" + MASK_ON: True + RESNETS: + DEPTH: 50 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + ROI_MASK_HEAD: + NORM: "SyncBN" +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-InstanceSegmentation/mask_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml new file mode 100644 index 0000000..f30488c --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_rcnn_ResNeSt_101_FPN_syncBN_1x.yaml @@ -0,0 +1,28 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest101_detectron-486f69a8.pth" + MASK_ON: True + RESNETS: + DEPTH: 101 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + ROI_MASK_HEAD: + NORM: "SyncBN" + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-InstanceSegmentation/mask_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml b/d2/configs/COCO-InstanceSegmentation/mask_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml new file mode 100644 index 0000000..a9112aa --- /dev/null +++ b/d2/configs/COCO-InstanceSegmentation/mask_rcnn_ResNeSt_50_FPN_syncBN_1x.yaml @@ -0,0 +1,28 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest50_detectron-255b5649.pth" + MASK_ON: True + RESNETS: + DEPTH: 50 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + ROI_MASK_HEAD: + NORM: "SyncBN" + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 +INPUT: + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True diff --git a/d2/configs/COCO-PanopticSegmentation/ResNeSt-Base-Panoptic-FPN.yaml b/d2/configs/COCO-PanopticSegmentation/ResNeSt-Base-Panoptic-FPN.yaml new file mode 100644 index 0000000..3ce4548 --- /dev/null +++ 
b/d2/configs/COCO-PanopticSegmentation/ResNeSt-Base-Panoptic-FPN.yaml @@ -0,0 +1,9 @@ +_BASE_: "../ResNest-Base-RCNN-FPN.yaml" +MODEL: + META_ARCHITECTURE: "PanopticFPN" + MASK_ON: True + SEM_SEG_HEAD: + LOSS_WEIGHT: 0.5 +DATASETS: + TRAIN: ("coco_2017_train_panoptic_separated",) + TEST: ("coco_2017_val_panoptic_separated",) diff --git a/d2/configs/COCO-PanopticSegmentation/panoptic_ResNeSt_200_FPN_syncBN_tricks_3x.yaml b/d2/configs/COCO-PanopticSegmentation/panoptic_ResNeSt_200_FPN_syncBN_tricks_3x.yaml new file mode 100644 index 0000000..0c33ad8 --- /dev/null +++ b/d2/configs/COCO-PanopticSegmentation/panoptic_ResNeSt_200_FPN_syncBN_tricks_3x.yaml @@ -0,0 +1,42 @@ +_BASE_: "ResNeSt-Base-Panoptic-FPN.yaml" +MODEL: + WEIGHTS: "https://s3.us-west-1.wasabisys.com/resnest/detectron/resnest200_detectron-02644020.pth" + RESNETS: + DEPTH: 200 + STRIDE_IN_1X1: False + RADIX: 2 + NORM: "SyncBN" + FPN: + NORM: "SyncBN" + ROI_HEADS: + NAME: CascadeROIHeads + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_CONV: 4 + NUM_FC: 1 + NORM: "SyncBN" + CLS_AGNOSTIC_BBOX_REG: True + SEM_SEG_HEAD: + NORM: "SyncBN" + RPN: + POST_NMS_TOPK_TRAIN: 2000 + PIXEL_MEAN: [123.68, 116.779, 103.939] + PIXEL_STD: [58.393, 57.12, 57.375] +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.02 + STEPS: (240000, 255000) + MAX_ITER: 270000 +INPUT: + MIN_SIZE_TRAIN: (400, 1000) + MIN_SIZE_TRAIN_SAMPLING: "range" + MAX_SIZE_TRAIN: 1440 + FORMAT: "RGB" +TEST: + PRECISE_BN: + ENABLED: True + AUG: + ENABLED: True + + + diff --git a/d2/configs/ResNest-Base-RCNN-FPN.yaml b/d2/configs/ResNest-Base-RCNN-FPN.yaml new file mode 100644 index 0000000..619d279 --- /dev/null +++ b/d2/configs/ResNest-Base-RCNN-FPN.yaml @@ -0,0 +1,4 @@ +_BASE_: "Base-RCNN-FPN.yaml" +MODEL: + BACKBONE: + NAME: "build_resnest_fpn_backbone" diff --git a/d2/datasets/prepare_coco.py b/d2/datasets/prepare_coco.py new file mode 100644 index 0000000..b1d15d0 --- /dev/null +++ b/d2/datasets/prepare_coco.py @@ -0,0 +1,64 @@ +"""Prepare MS COCO datasets""" +import os +import shutil +import argparse +import zipfile +from resnest.utils import download, mkdir + +_TARGET_DIR = os.path.expanduser('./coco') + +def parse_args(): + parser = argparse.ArgumentParser( + description='Initialize MS COCO dataset.', + epilog='Example: python mscoco.py --download-dir ~/mscoco', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('--download-dir', type=str, default=None, help='dataset directory on disk') + args = parser.parse_args() + return args + +def download_coco(path, overwrite=False): + _DOWNLOAD_URLS = [ + ('http://images.cocodataset.org/zips/train2017.zip', + '10ad623668ab00c62c096f0ed636d6aff41faca5'), + ('http://images.cocodataset.org/zips/val2017.zip', + '4950dc9d00dbe1c933ee0170f5797584351d2a41'), + ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', + '8551ee4bb5860311e79dace7e79cb91e432e78b3'), + ('https://hangzh.s3.amazonaws.com/encoding/data/coco/train_ids.pth', + '12cd266f97c8d9ea86e15a11f11bcb5faba700b6'), + ('https://hangzh.s3.amazonaws.com/encoding/data/coco/val_ids.pth', + '4ce037ac33cbf3712fd93280a1c5e92dae3136bb'), + ] + mkdir(path) + for url, checksum in _DOWNLOAD_URLS: + filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) + # extract + if os.path.splitext(filename)[1] == '.zip': + with zipfile.ZipFile(filename) as zf: + zf.extractall(path=path) + else: + shutil.move(filename, os.path.join(path, 'annotations/'+os.path.basename(filename))) + + +def install_coco_api(): + repo_url = 
"https://github.com/cocodataset/cocoapi" + os.system("git clone " + repo_url) + os.system("cd cocoapi/PythonAPI/ && python setup.py install") + shutil.rmtree('cocoapi') + try: + import pycocotools + except Exception: + print("Installing COCO API failed, please install it manually %s"%(repo_url)) + + +if __name__ == '__main__': + args = parse_args() + mkdir(os.path.expanduser('~/.encoding/data')) + if args.download_dir is not None: + if os.path.isdir(_TARGET_DIR): + os.remove(_TARGET_DIR) + # make symlink + os.symlink(args.download_dir, _TARGET_DIR) + else: + download_coco(_TARGET_DIR, overwrite=False) + install_coco_api() diff --git a/d2/train_net.py b/d2/train_net.py new file mode 100644 index 0000000..6fbb7b6 --- /dev/null +++ b/d2/train_net.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Detection Training Script. + +This scripts reads a given config file and runs the training or evaluation. +It is an entry point that is made to train standard models in detectron2. + +In order to let one script support training of many models, +this script contains logic that are specific to these built-in models and therefore +may not be suitable for your own project. +For example, your research project perhaps only needs a single "evaluator". + +Therefore, we recommend you to use detectron2 as an library and take +this file as an example of how to use the library. +You may want to write your own script with your datasets and other customizations. +""" + +import logging +import os +from collections import OrderedDict +import torch + +import detectron2.utils.comm as comm +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import get_cfg +from detectron2.data import MetadataCatalog +from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch +from detectron2.evaluation import ( + CityscapesInstanceEvaluator, + CityscapesSemSegEvaluator, + COCOEvaluator, + COCOPanopticEvaluator, + DatasetEvaluators, + LVISEvaluator, + PascalVOCDetectionEvaluator, + SemSegEvaluator, + verify_results, +) +from detectron2.modeling import GeneralizedRCNNWithTTA +from resnest.d2 import add_resnest_config + + +class Trainer(DefaultTrainer): + """ + We use the "DefaultTrainer" which contains pre-defined default logic for + standard training workflow. They may not work for you, especially if you + are working on a new research project. In that case you can write your + own training loop. You can use "tools/plain_train_net.py" as an example. + """ + + @classmethod + def build_evaluator(cls, cfg, dataset_name, output_folder=None): + """ + Create evaluator(s) for a given dataset. + This uses the special metadata "evaluator_type" associated with each builtin dataset. + For your own dataset, you can simply create an evaluator manually in your + script and do not have to worry about the hacky if-else logic here. 
+ """ + if output_folder is None: + output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") + evaluator_list = [] + evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type + if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: + evaluator_list.append( + SemSegEvaluator( + dataset_name, + distributed=True, + output_dir=output_folder, + ) + ) + if evaluator_type in ["coco", "coco_panoptic_seg"]: + evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder)) + if evaluator_type == "coco_panoptic_seg": + evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) + if evaluator_type == "cityscapes_instance": + assert ( + torch.cuda.device_count() >= comm.get_rank() + ), "CityscapesEvaluator currently do not work with multiple machines." + return CityscapesInstanceEvaluator(dataset_name) + if evaluator_type == "cityscapes_sem_seg": + assert ( + torch.cuda.device_count() >= comm.get_rank() + ), "CityscapesEvaluator currently do not work with multiple machines." + return CityscapesSemSegEvaluator(dataset_name) + elif evaluator_type == "pascal_voc": + return PascalVOCDetectionEvaluator(dataset_name) + elif evaluator_type == "lvis": + return LVISEvaluator(dataset_name, output_dir=output_folder) + if len(evaluator_list) == 0: + raise NotImplementedError( + "no Evaluator for the dataset {} with the type {}".format( + dataset_name, evaluator_type + ) + ) + elif len(evaluator_list) == 1: + return evaluator_list[0] + return DatasetEvaluators(evaluator_list) + + @classmethod + def test_with_TTA(cls, cfg, model): + logger = logging.getLogger("detectron2.trainer") + # In the end of training, run an evaluation with TTA + # Only support some R-CNN models. + logger.info("Running inference with test-time augmentation ...") + model = GeneralizedRCNNWithTTA(cfg, model) + evaluators = [ + cls.build_evaluator( + cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") + ) + for name in cfg.DATASETS.TEST + ] + res = cls.test(cfg, model, evaluators) + res = OrderedDict({k + "_TTA": v for k, v in res.items()}) + return res + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg() + add_resnest_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + cfg.freeze() + default_setup(cfg, args) + return cfg + + +def main(args): + cfg = setup(args) + + if args.eval_only: + model = Trainer.build_model(cfg) + DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( + cfg.MODEL.WEIGHTS, resume=args.resume + ) + res = Trainer.test(cfg, model) + if cfg.TEST.AUG.ENABLED: + res.update(Trainer.test_with_TTA(cfg, model)) + if comm.is_main_process(): + verify_results(cfg, res) + return res + + """ + If you'd like to do anything fancier than the standard training logic, + consider writing your own training loop (see plain_train_net.py) or + subclassing the trainer. 
+ """ + trainer = Trainer(cfg) + trainer.resume_or_load(resume=args.resume) + if cfg.TEST.AUG.ENABLED: + trainer.register_hooks( + [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] + ) + return trainer.train() + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + print("Command Line Args:", args) + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/resnest/d2/__init__.py b/resnest/d2/__init__.py new file mode 100644 index 0000000..74ff03a --- /dev/null +++ b/resnest/d2/__init__.py @@ -0,0 +1,2 @@ +from .resnest import build_resnest_backbone, build_resnest_fpn_backbone +from .config import add_resnest_config diff --git a/resnest/d2/config.py b/resnest/d2/config.py new file mode 100644 index 0000000..471347f --- /dev/null +++ b/resnest/d2/config.py @@ -0,0 +1,20 @@ +from detectron2.config import CfgNode as CN + +def add_resnest_config(cfg): + """Add config for ResNeSt + """ + # Place the stride 2 conv on the 1x1 filter + # Use True only for the original MSRA ResNet; + # use False for C2 and Torch models + cfg.MODEL.RESNETS.STRIDE_IN_1X1 = False + # Apply deep stem + cfg.MODEL.RESNETS.DEEP_STEM = True + # Apply avg after conv2 in the BottleBlock + # When AVD=True, the STRIDE_IN_1X1 should be False + cfg.MODEL.RESNETS.AVD = True + # Apply avg_down to the downsampling layer for residual path + cfg.MODEL.RESNETS.AVG_DOWN = True + # Radix in ResNeSt + cfg.MODEL.RESNETS.RADIX = 2 + # Bottleneck_width in ResNeSt + cfg.MODEL.RESNETS.BOTTLENECK_WIDTH = 64 diff --git a/resnest/d2/resnest.py b/resnest/d2/resnest.py new file mode 100644 index 0000000..2d881c8 --- /dev/null +++ b/resnest/d2/resnest.py @@ -0,0 +1,734 @@ +import numpy as np +import fvcore.nn.weight_init as weight_init +import torch +import torch.nn.functional as F +from torch import nn + +from detectron2.layers import ( + Conv2d, + DeformConv, + FrozenBatchNorm2d, + ModulatedDeformConv, + ShapeSpec, + get_norm, +) + +from detectron2.modeling.backbone import Backbone, FPN, BACKBONE_REGISTRY +from detectron2.modeling.backbone.fpn import LastLevelMaxPool + +__all__ = [ + "ResNeSt", + "build_resnest_backbone", + "build_resnest_fpn_backbone", +] + + +class ResNetBlockBase(nn.Module): + def __init__(self, in_channels, out_channels, stride): + """ + The `__init__` method of any subclass should also contain these arguments. + + Args: + in_channels (int): + out_channels (int): + stride (int): + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + FrozenBatchNorm2d.convert_frozen_batchnorm(self) + return self + + +class BasicBlock(ResNetBlockBase): + def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"): + """ + The standard block type for ResNet18 and ResNet34. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int): Stride for the first conv. + norm (str or callable): A callable that takes the number of + channels and returns a `nn.Module`, or a pre-defined string + (one of {"FrozenBN", "BN", "GN"}). 
+ """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + self.conv1 = Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + self.conv2 = Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + out = self.conv2(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class BottleneckBlock(ResNetBlockBase): + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + stride_in_1x1=False, + dilation=1, + avd=False, + avg_down=False, + radix=2, + bottleneck_width=64, + ): + """ + Args: + norm (str or callable): a callable that takes the number of + channels and return a `nn.Module`, or a pre-defined string + (one of {"FrozenBN", "BN", "GN"}). + stride_in_1x1 (bool): when stride==2, whether to put stride in the + first 1x1 convolution or the bottleneck 3x3 convolution. + """ + super().__init__(in_channels, out_channels, stride) + + self.avd = avd and (stride>1) + self.avg_down = avg_down + self.radix = radix + + cardinality = num_groups + group_width = int(bottleneck_channels * (bottleneck_width / 64.)) * cardinality + + if in_channels != out_channels: + if self.avg_down: + self.shortcut_avgpool = nn.AvgPool2d(kernel_size=stride, stride=stride, + ceil_mode=True, count_include_pad=False) + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + group_width, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, group_width), + ) + + if self.radix>1: + from .splat import SplAtConv2d + self.conv2 = SplAtConv2d( + group_width, group_width, kernel_size=3, + stride = 1 if self.avd else stride_3x3, + padding=dilation, dilation=dilation, + groups=cardinality, bias=False, + radix=self.radix, + norm=norm, + ) + else: + self.conv2 = Conv2d( + group_width, + group_width, + kernel_size=3, + stride=1 if self.avd else stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + norm=get_norm(norm, group_width), + ) + + if self.avd: + self.avd_layer = nn.AvgPool2d(3, stride, padding=1) + + self.conv3 = Conv2d( + group_width, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + if self.radix>1: + for layer in [self.conv1, self.conv3, self.shortcut]: + if layer is not None: # shortcut can be None + 
weight_init.c2_msra_fill(layer) + else: + for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + # Zero-initialize the last normalization in each residual branch, + # so that at the beginning, the residual branch starts with zeros, + # and each residual block behaves like an identity. + # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "For BN layers, the learnable scaling coefficient γ is initialized + # to be 1, except for each residual block's last BN + # where γ is initialized to be 0." + + # nn.init.constant_(self.conv3.norm.weight, 0) + # TODO this somehow hurts performance when training GN models from scratch. + # Add it as an option when we need to use this code to train a backbone. + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + if self.radix>1: + out = self.conv2(out) + else: + out = self.conv2(out) + out = F.relu_(out) + + if self.avd: + out = self.avd_layer(out) + + out = self.conv3(out) + + if self.shortcut is not None: + if self.avg_down: + x = self.shortcut_avgpool(x) + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class DeformBottleneckBlock(ResNetBlockBase): + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + stride_in_1x1=False, + dilation=1, + deform_modulated=False, + deform_num_groups=1, + avd=False, + avg_down=False, + radix=2, + bottleneck_width=64, + ): + """ + Similar to :class:`BottleneckBlock`, but with deformable conv in the 3x3 convolution. + """ + super().__init__(in_channels, out_channels, stride) + self.deform_modulated = deform_modulated + self.avd = avd and (stride>1) + self.avg_down = avg_down + self.radix = radix + + cardinality = num_groups + group_width = int(bottleneck_channels * (bottleneck_width / 64.)) * cardinality + + if in_channels != out_channels: + if self.avg_down: + self.shortcut_avgpool = nn.AvgPool2d(kernel_size=stride, stride=stride, + ceil_mode=True, count_include_pad=False) + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + group_width, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, group_width), + ) + + if deform_modulated: + deform_conv_op = ModulatedDeformConv + # offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size + offset_channels = 27 + else: + deform_conv_op = DeformConv + offset_channels = 18 + + self.conv2_offset = Conv2d( + bottleneck_channels, + offset_channels * deform_num_groups, + kernel_size=3, + stride=1 if self.avd else stride_3x3, + padding=1 * dilation, + dilation=dilation, + groups=deform_num_groups, + ) + if self.radix>1: + from .splat import SplAtConv2d_dcn + self.conv2 = SplAtConv2d_dcn( + group_width, group_width, kernel_size=3, + stride = 1 if self.avd else stride_3x3, + padding=dilation, dilation=dilation, + groups=cardinality, bias=False, + radix=self.radix, + norm=norm, + deform_conv_op=deform_conv_op, + deformable_groups=deform_num_groups, + deform_modulated=deform_modulated, + + ) + 
else: + self.conv2 = deform_conv_op( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=1 if self.avd else stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + deformable_groups=deform_num_groups, + norm=get_norm(norm, bottleneck_channels), + ) + + if self.avd: + self.avd_layer = nn.AvgPool2d(3, stride, padding=1) + + self.conv3 = Conv2d( + group_width, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + if self.radix>1: + for layer in [self.conv1, self.conv3, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + else: + for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + nn.init.constant_(self.conv2_offset.weight, 0) + nn.init.constant_(self.conv2_offset.bias, 0) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + if self.radix>1: + offset = self.conv2_offset(out) + out = self.conv2(out, offset) + else: + if self.deform_modulated: + offset_mask = self.conv2_offset(out) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + out = self.conv2(out, offset, mask) + else: + offset = self.conv2_offset(out) + out = self.conv2(out, offset) + out = F.relu_(out) + + if self.avd: + out = self.avd_layer(out) + + out = self.conv3(out) + + if self.shortcut is not None: + if self.avg_down: + x = self.shortcut_avgpool(x) + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +def make_stage(block_class, num_blocks, first_stride, **kwargs): + """ + Create a resnet stage by creating many blocks. + + Args: + block_class (class): a subclass of ResNetBlockBase + num_blocks (int): + first_stride (int): the stride of the first block. The other blocks will have stride=1. + A `stride` argument will be passed to the block constructor. + kwargs: other arguments passed to the block constructor. + + Returns: + list[nn.Module]: a list of block module. + """ + blocks = [] + for i in range(num_blocks): + blocks.append(block_class(stride=first_stride if i == 0 else 1, **kwargs)) + kwargs["in_channels"] = kwargs["out_channels"] + return blocks + + +class BasicStem(nn.Module): + def __init__(self, in_channels=3, out_channels=64, norm="BN", + deep_stem=False, stem_width=32): + """ + Args: + norm (str or callable): a callable that takes the number of + channels and return a `nn.Module`, or a pre-defined string + (one of {"FrozenBN", "BN", "GN"}). 
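+            deep_stem (bool): if True, replace the single 7x7 stem conv with
+                three 3x3 convs (the deep stem used by ResNeSt).
+            stem_width (int): channel width of the first two 3x3 convs when
+                `deep_stem` is True; the stem then outputs `stem_width * 2` channels.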
+ """ + super().__init__() + self.deep_stem = deep_stem + + if self.deep_stem: + self.conv1_1 = Conv2d(3, stem_width, kernel_size=3, stride=2, + padding=1, bias=False, + norm=get_norm(norm, stem_width), + ) + self.conv1_2 = Conv2d(stem_width, stem_width, kernel_size=3, stride=1, + padding=1, bias=False, + norm=get_norm(norm, stem_width), + ) + self.conv1_3 = Conv2d(stem_width, stem_width*2, kernel_size=3, stride=1, + padding=1, bias=False, + norm=get_norm(norm, stem_width*2), + ) + for layer in [self.conv1_1, self.conv1_2, self.conv1_3]: + if layer is not None: + weight_init.c2_msra_fill(layer) + else: + self.conv1 = Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False, + norm=get_norm(norm, out_channels), + ) + weight_init.c2_msra_fill(self.conv1) + + def forward(self, x): + if self.deep_stem: + x = self.conv1_1(x) + x = F.relu_(x) + x = self.conv1_2(x) + x = F.relu_(x) + x = self.conv1_3(x) + x = F.relu_(x) + else: + x = self.conv1(x) + x = F.relu_(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + @property + def out_channels(self): + if self.deep_stem: + return self.conv1_3.out_channels + else: + return self.conv1.out_channels + + @property + def stride(self): + return 4 # = stride 2 conv -> stride 2 max pool + + +class ResNeSt(Backbone): + def __init__(self, stem, stages, num_classes=None, out_features=None): + """ + Args: + stem (nn.Module): a stem module + stages (list[list[ResNetBlock]]): several (typically 4) stages, + each contains multiple :class:`ResNetBlockBase`. + num_classes (None or int): if None, will not perform classification. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. Can be anything in "stem", "linear", or "res2" ... + If None, will return the output of the last layer. + """ + super(ResNeSt, self).__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stages_and_names = [] + for i, blocks in enumerate(stages): + for block in blocks: + assert isinstance(block, ResNetBlockBase), block + curr_channels = block.out_channels + stage = nn.Sequential(*blocks) + name = "res" + str(i + 2) + self.add_module(name, stage) + self.stages_and_names.append((stage, name)) + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = blocks[-1].out_channels + + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(curr_channels, num_classes) + + # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "The 1000-way fully-connected layer is initialized by + # drawing weights from a zero-mean Gaussian with standard deviation of 0.01." 
+ nn.init.normal_(self.linear.weight, std=0.01) + name = "linear" + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {}".format(", ".join(children)) + + def forward(self, x): + outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for stage, name in self.stages_and_names: + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + if "linear" in self._out_features: + outputs["linear"] = x + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + +@BACKBONE_REGISTRY.register() +def build_resnest_backbone(cfg, input_shape): + """ + Create a ResNeSt instance from config. + + Returns: + ResNeSt: a :class:`ResNeSt` instance. + """ + + depth = cfg.MODEL.RESNETS.DEPTH + stem_width = {50: 32, 101: 64, 152: 64, 200: 64, 269: 64}[depth] + radix = cfg.MODEL.RESNETS.RADIX + deep_stem = cfg.MODEL.RESNETS.DEEP_STEM or (radix > 1) + + # need registration of new blocks/stems? + norm = cfg.MODEL.RESNETS.NORM + stem = BasicStem( + in_channels=input_shape.channels, + out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + deep_stem=deep_stem, + stem_width=stem_width, + ) + freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT + + if freeze_at >= 1: + for p in stem.parameters(): + p.requires_grad = False + stem = FrozenBatchNorm2d.convert_frozen_batchnorm(stem) + + # fmt: off + out_features = cfg.MODEL.RESNETS.OUT_FEATURES + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION + deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE + deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED + deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS + avd = cfg.MODEL.RESNETS.AVD or (radix > 1) + avg_down = cfg.MODEL.RESNETS.AVG_DOWN or (radix > 1) + bottleneck_width = cfg.MODEL.RESNETS.BOTTLENECK_WIDTH + # fmt: on + assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) + + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + 200: [3, 24, 36, 3], + 269: [3, 30, 48, 8], + }[depth] + + if depth in [18, 34]: + assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34" + assert not any( + deform_on_per_stage + ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34" + assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34" + assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34" + + stages = [] + + # Avoid creating variables without gradients + # It consumes extra memory and may cause allreduce to fail + out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features] + max_stage_idx = max(out_stage_idx) + in_channels = 2*stem_width if deep_stem else in_channels + for idx, stage_idx in enumerate(range(2, 
+    for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)):
+        dilation = res5_dilation if stage_idx == 5 else 1
+        first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2
+        stage_kargs = {
+            "num_blocks": num_blocks_per_stage[idx],
+            "first_stride": first_stride,
+            "in_channels": in_channels,
+            "out_channels": out_channels,
+            "norm": norm,
+            "avd": avd,
+            "avg_down": avg_down,
+            "radix": radix,
+            "bottleneck_width": bottleneck_width,
+        }
+        # Use BasicBlock for R18 and R34.
+        if depth in [18, 34]:
+            stage_kargs["block_class"] = BasicBlock
+        else:
+            stage_kargs["bottleneck_channels"] = bottleneck_channels
+            stage_kargs["stride_in_1x1"] = stride_in_1x1
+            stage_kargs["dilation"] = dilation
+            stage_kargs["num_groups"] = num_groups
+            if deform_on_per_stage[idx]:
+                stage_kargs["block_class"] = DeformBottleneckBlock
+                stage_kargs["deform_modulated"] = deform_modulated
+                stage_kargs["deform_num_groups"] = deform_num_groups
+            else:
+                stage_kargs["block_class"] = BottleneckBlock
+        blocks = make_stage(**stage_kargs)
+        in_channels = out_channels
+        out_channels *= 2
+        bottleneck_channels *= 2
+
+        if freeze_at >= stage_idx:
+            for block in blocks:
+                block.freeze()
+        stages.append(blocks)
+    return ResNeSt(stem, stages, out_features=out_features)
+
+
+@BACKBONE_REGISTRY.register()
+def build_resnest_fpn_backbone(cfg, input_shape: ShapeSpec):
+    """
+    Args:
+        cfg: a detectron2 CfgNode
+
+    Returns:
+        backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
+    """
+    bottom_up = build_resnest_backbone(cfg, input_shape)
+    in_features = cfg.MODEL.FPN.IN_FEATURES
+    out_channels = cfg.MODEL.FPN.OUT_CHANNELS
+    backbone = FPN(
+        bottom_up=bottom_up,
+        in_features=in_features,
+        out_channels=out_channels,
+        norm=cfg.MODEL.FPN.NORM,
+        top_block=LastLevelMaxPool(),
+        fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
+    )
+    return backbone
diff --git a/resnest/d2/splat.py b/resnest/d2/splat.py
new file mode 100644
index 0000000..b48f94c
--- /dev/null
+++ b/resnest/d2/splat.py
@@ -0,0 +1,179 @@
+"""Split-Attention"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import Module, Linear, BatchNorm2d, ReLU
+from torch.nn.modules.utils import _pair
+
+from detectron2.layers import (
+    Conv2d,
+    get_norm,
+)
+
+__all__ = ['SplAtConv2d', 'SplAtConv2d_dcn']
+
+
+class SplAtConv2d(Module):
+    """Split-Attention Conv2d
+    """
+    def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
+                 dilation=(1, 1), groups=1, bias=True,
+                 radix=2, reduction_factor=4,
+                 rectify=False, rectify_avg=False, norm=None,
+                 dropblock_prob=0.0, **kwargs):
+        super(SplAtConv2d, self).__init__()
+        padding = _pair(padding)
+        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
+        self.rectify_avg = rectify_avg
+        inter_channels = max(in_channels*radix//reduction_factor, 32)
+        self.radix = radix
+        self.cardinality = groups
+        self.channels = channels
+        self.dropblock_prob = dropblock_prob
+        if self.rectify:
+            from rfconv import RFConv2d
+            self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
+                                 groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs)
+        else:
+            self.conv = Conv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
+                               groups=groups*radix, bias=bias, **kwargs)
+        self.use_bn = norm is not None
+        if self.use_bn:
+            self.bn0 = get_norm(norm, channels*radix)
+        self.relu = ReLU(inplace=True)
+        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
+        if self.use_bn:
+            self.bn1 = get_norm(norm, inter_channels)
+        self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality)
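+        # fc1/fc2 form the split-attention gating MLP: a globally pooled
+        # descriptor is reduced to inter_channels by fc1, expanded back to
+        # channels*radix by fc2, and normalized over the radix splits by
+        # rSoftMax in forward().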
+        if dropblock_prob > 0.0:
+            # NOTE: DropBlock2D is neither defined nor imported in this module;
+            # constructing the layer with dropblock_prob > 0.0 will raise a
+            # NameError unless a DropBlock2D implementation is supplied.
+            self.dropblock = DropBlock2D(dropblock_prob, 3)
+        self.rsoftmax = rSoftMax(radix, groups)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.use_bn:
+            x = self.bn0(x)
+        if self.dropblock_prob > 0.0:
+            x = self.dropblock(x)
+        x = self.relu(x)
+
+        batch, rchannel = x.shape[:2]
+        if self.radix > 1:
+            splited = torch.split(x, rchannel//self.radix, dim=1)
+            gap = sum(splited)
+        else:
+            gap = x
+        gap = F.adaptive_avg_pool2d(gap, 1)
+        gap = self.fc1(gap)
+
+        if self.use_bn:
+            gap = self.bn1(gap)
+        gap = self.relu(gap)
+
+        atten = self.fc2(gap)
+        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+        if self.radix > 1:
+            attens = torch.split(atten, rchannel//self.radix, dim=1)
+            out = sum([att*split for (att, split) in zip(attens, splited)])
+        else:
+            out = atten * x
+        return out.contiguous()
+
+
+class rSoftMax(nn.Module):
+    def __init__(self, radix, cardinality):
+        super().__init__()
+        self.radix = radix
+        self.cardinality = cardinality
+
+    def forward(self, x):
+        batch = x.size(0)
+        if self.radix > 1:
+            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
+            x = F.softmax(x, dim=1)
+            x = x.reshape(batch, -1)
+        else:
+            x = torch.sigmoid(x)
+        return x
+
+
+class SplAtConv2d_dcn(Module):
+    """Split-Attention Conv2d with dcn
+    """
+    def __init__(self, in_channels, channels, kernel_size, stride=(1, 1), padding=(0, 0),
+                 dilation=(1, 1), groups=1, bias=True,
+                 radix=2, reduction_factor=4,
+                 rectify=False, rectify_avg=False, norm=None,
+                 dropblock_prob=0.0,
+                 deform_conv_op=None,
+                 deformable_groups=1,
+                 deform_modulated=False,
+                 **kwargs):
+        super(SplAtConv2d_dcn, self).__init__()
+        self.deform_modulated = deform_modulated
+
+        padding = _pair(padding)
+        self.rectify = rectify and (padding[0] > 0 or padding[1] > 0)
+        self.rectify_avg = rectify_avg
+        inter_channels = max(in_channels*radix//reduction_factor, 32)
+        self.radix = radix
+        self.cardinality = groups
+        self.channels = channels
+        self.dropblock_prob = dropblock_prob
+        if self.rectify:
+            from rfconv import RFConv2d
+            self.conv = RFConv2d(in_channels, channels*radix, kernel_size, stride, padding, dilation,
+                                 groups=groups*radix, bias=bias, average_mode=rectify_avg, **kwargs)
+        else:
+            self.conv = deform_conv_op(in_channels, channels*radix, kernel_size, stride, padding[0], dilation,
+                                       groups=groups*radix, bias=bias, deformable_groups=deformable_groups, **kwargs)
+        self.use_bn = norm is not None
+        if self.use_bn:
+            self.bn0 = get_norm(norm, channels*radix)
+        self.relu = ReLU(inplace=True)
+        self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality)
+        if self.use_bn:
+            self.bn1 = get_norm(norm, inter_channels)
+        self.fc2 = Conv2d(inter_channels, channels*radix, 1, groups=self.cardinality)
+        if dropblock_prob > 0.0:
+            # NOTE: see SplAtConv2d above -- DropBlock2D is not available in
+            # this module, so dropblock_prob > 0.0 will raise a NameError.
+            self.dropblock = DropBlock2D(dropblock_prob, 3)
+        self.rsoftmax = rSoftMax(radix, groups)
+
+    def forward(self, x, offset_input):
+        if self.deform_modulated:
+            offset_x, offset_y, mask = torch.chunk(offset_input, 3, dim=1)
+            offset = torch.cat((offset_x, offset_y), dim=1)
+            mask = mask.sigmoid()
+            x = self.conv(x, offset, mask)
+        else:
+            x = self.conv(x, offset_input)
+
+        if self.use_bn:
+            x = self.bn0(x)
+        if self.dropblock_prob > 0.0:
+            x = self.dropblock(x)
+        x = self.relu(x)
+
+        batch, rchannel = x.shape[:2]
+        if self.radix > 1:
+            splited = torch.split(x, rchannel//self.radix, dim=1)
+            gap = sum(splited)
+        else:
+            gap = x
+        gap = F.adaptive_avg_pool2d(gap, 1)
+        gap = self.fc1(gap)
+
+        if self.use_bn:
+            gap = self.bn1(gap)
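+        # From here on, the pooled descriptor gates the radix splits exactly
+        # as in SplAtConv2d.forward(); only the (deformable) convolution at
+        # the top of this method differs.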
+        gap = self.relu(gap)
+
+        atten = self.fc2(gap)
+        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+        if self.radix > 1:
+            attens = torch.split(atten, rchannel//self.radix, dim=1)
+            out = sum([att*split for (att, split) in zip(attens, splited)])
+        else:
+            out = atten * x
+        return out.contiguous()
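
For reference, a minimal smoke test of SplAtConv2d (a sketch, not part of the
patch; it assumes the patch is applied and detectron2/PyTorch are installed,
and the channel counts are arbitrary example values):

    # Verifies output shapes of the split-attention layer with default radix=2.
    # norm="BN" resolves to BatchNorm2d via detectron2's get_norm.
    import torch
    from resnest.d2.splat import SplAtConv2d

    layer = SplAtConv2d(64, 64, kernel_size=3, padding=1, groups=1,
                        radix=2, norm="BN")
    x = torch.randn(2, 64, 32, 32)
    out = layer(x)
    # Split-attention returns `channels` feature maps at the input resolution.
    assert out.shape == (2, 64, 32, 32)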