A PyTorch implementation of DANet based on the CVPR 2019 paper Dual Attention Network for Scene Segmentation.
- Anaconda
- PyTorch: `conda install pytorch torchvision cudatoolkit=10.1 -c pytorch`
- opencv: `pip install opencv-python`
- tensorboard: `pip install tensorboard`
- pycocotools: `pip install git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI`
- fvcore: `pip install git+https://github.com/facebookresearch/fvcore`
- cityscapesScripts: `pip install git+https://github.com/mcordts/cityscapesScripts.git`
- detectron2: `pip install git+https://github.com/facebookresearch/detectron2.git@master`
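After installation, a quick way to confirm that Python can find the packages above is a check like the following. It is a small sketch, not a script shipped with this repo; the import names it assumes (`cv2` for opencv, `cityscapesscripts` for cityscapesScripts, and so on) are the usual top-level module names.

```python
import importlib.util

# Top-level module names assumed for each requirement listed above.
REQUIRED_MODULES = {
    "PyTorch": "torch",
    "torchvision": "torchvision",
    "opencv": "cv2",
    "tensorboard": "tensorboard",
    "pycocotools": "pycocotools",
    "fvcore": "fvcore",
    "cityscapesScripts": "cityscapesscripts",
    "detectron2": "detectron2",
}


def check_requirements(modules=REQUIRED_MODULES):
    """Return {package name: True/False} depending on whether its module can be found."""
    # For top-level names, find_spec returns None (without importing) when the module is missing.
    return {name: importlib.util.find_spec(mod) is not None for name, mod in modules.items()}


if __name__ == "__main__":
    for name, ok in check_requirements().items():
        print(f"{name:18s} {'ok' if ok else 'MISSING'}")
```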
For the datasets that detectron2 natively supports, the data is assumed to exist in a directory called `datasets/` under the directory where you launch the program. It needs to have the following directory structure (shown for Cityscapes):
- `cityscapes/`
  - `gtFine/`
    - `train/`
      - `aachen/`
        - `color.png`, `instanceIds.png`, `labelIds.png`, `polygons.json`, `labelTrainIds.png`
      - ...
    - `val/`
    - `test/`
  - `leftImg8bit/`
    - `train/`
    - `val/`
    - `test/`
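Before training, the layout above can be checked with a short script such as this one. It is a sketch that only looks for the `train`/`val`/`test` folders from the tree and for at least one `*labelTrainIds.png` file under `gtFine/train/`; the function names are hypothetical and the default root assumes the `datasets/cityscapes` location described above.

```python
from pathlib import Path

# Sub-directories expected under datasets/cityscapes/ (taken from the tree above).
EXPECTED_SUBDIRS = [
    f"{kind}/{split}"
    for kind in ("gtFine", "leftImg8bit")
    for split in ("train", "val", "test")
]


def missing_cityscapes_dirs(root="datasets/cityscapes"):
    """Return the expected sub-directories that do not exist under `root`."""
    root = Path(root)
    return [sub for sub in EXPECTED_SUBDIRS if not (root / sub).is_dir()]


def has_train_id_labels(root="datasets/cityscapes"):
    """True if at least one *labelTrainIds.png exists in a gtFine/train/<city>/ folder."""
    return any(Path(root).glob("gtFine/train/*/*labelTrainIds.png"))


if __name__ == "__main__":
    missing = missing_cityscapes_dirs()
    print("layout OK" if not missing else f"missing: {missing}")
    print("labelTrainIds found" if has_train_id_labels() else "labelTrainIds not generated yet")
```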
Run `./datasets/prepare_cityscapes.py` to create `labelTrainIds.png`.
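`labelTrainIds.png` stores, for every pixel, the training id of its class instead of the raw Cityscapes label id. The sketch below shows that conversion under the assumption that the standard cityscapesScripts mapping is used (19 evaluated classes, every other id set to the ignore value 255). It is not the code of `prepare_cityscapes.py`, which may differ in details, and `convert_file` is a hypothetical helper.

```python
import numpy as np

# Standard Cityscapes labelId -> trainId table (19 evaluated classes, as in
# cityscapesScripts' labels.py); all other ids map to the ignore value 255.
ID_TO_TRAIN_ID = {
    7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9,
    23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18,
}
IGNORE_LABEL = 255


def build_lut():
    """256-entry lookup table covering every possible uint8 label id."""
    lut = np.full(256, IGNORE_LABEL, dtype=np.uint8)
    for label_id, train_id in ID_TO_TRAIN_ID.items():
        lut[label_id] = train_id
    return lut


def label_ids_to_train_ids(label_ids):
    """Map a uint8 labelIds image of shape (H, W) to a labelTrainIds image."""
    return build_lut()[label_ids]


def convert_file(src, dst):
    """Hypothetical helper: read a *_labelIds.png and write a *_labelTrainIds.png."""
    import cv2  # opencv-python, from the requirements above

    label_ids = cv2.imread(src, cv2.IMREAD_GRAYSCALE)
    cv2.imwrite(dst, label_ids_to_train_ids(label_ids))
```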
To train a model, run `python train_net.py --config-file <config.yaml>`.
For example, to launch end-to-end DANet training with a ResNet-50 backbone on 8 GPUs, run `python train_net.py --config-file configs/r50.yaml --num-gpus 8`.
Model evaluation can be done similarly: `python train_net.py --config-file configs/r50.yaml --num-gpus 8 --eval-only MODEL.WEIGHTS checkpoints/model.pth`
This implementation differs from the official implementation in the following ways:
- No Multi-Grid and no Multi-Scale Testing;
- The image sizes used for Multi-Scale Training are (800, 832, 864, 896, 928, 960);
- The number of training iterations is set to 24000;
- The learning rate policy is `WarmupMultiStepLR`;
- The Position Attention Module (PAM) uses a mechanism similar to the Channel Attention Module (CAM): it computes attention directly from the feature tensor and its transpose.
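To make the last point concrete, here is a minimal PyTorch sketch of the two attention blocks as described above. CAM multiplies the reshaped feature (C x N, with N = H x W) by its transpose to get a C x C channel affinity; PAM is computed the same way but in the transposed order, giving an N x N affinity between positions, with no separate query/key projections. The class names, the zero-initialised `gamma` residual weight and the plain softmax are assumptions for illustration; the modules in this repo may differ in details.

```python
import torch
from torch import nn


class ChannelAttention(nn.Module):
    """CAM-style block: a (C x C) affinity between channels, from feat @ feat^T."""

    def __init__(self):
        super().__init__()
        # Residual weight, zero-initialised so the block starts as the identity (assumption).
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.shape
        feat = x.reshape(b, c, h * w)                     # (B, C, N), N = H * W
        energy = torch.bmm(feat, feat.transpose(1, 2))    # (B, C, C)
        attn = torch.softmax(energy, dim=-1)              # each row sums to 1
        out = torch.bmm(attn, feat).reshape(b, c, h, w)   # each channel = weighted mix of channels
        return self.gamma * out + x


class PositionAttention(nn.Module):
    """PAM computed like CAM: a (N x N) affinity between positions, from feat^T @ feat."""

    def __init__(self):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, h, w = x.shape
        feat = x.reshape(b, c, h * w)                     # (B, C, N)
        energy = torch.bmm(feat.transpose(1, 2), feat)    # (B, N, N)
        attn = torch.softmax(energy, dim=-1)              # attn[b, j, i]: weight of position i for j
        # out[:, :, j] = sum_i feat[:, :, i] * attn[:, j, i]
        out = torch.bmm(feat, attn.transpose(1, 2)).reshape(b, c, h, w)
        return self.gamma * out + x
```

Because `gamma` starts at zero, both blocks initially pass features through unchanged and learn how much attention output to add during training.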
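detectron2's `WarmupMultiStepLR` combines a linear warmup with step decay at fixed milestones. The pure-Python sketch below reproduces that rule so the learning rate over the 24000 training iterations can be inspected. The warmup defaults shown (`warmup_factor=0.001`, `warmup_iters=1000`, `gamma=0.1`) are detectron2's config defaults; the `base_lr` and milestones in the example are hypothetical placeholders, and the values actually used are set in the repo's config files.

```python
from bisect import bisect_right


def warmup_multistep_lr(iteration, base_lr, milestones, gamma=0.1,
                        warmup_factor=0.001, warmup_iters=1000):
    """Learning rate at `iteration` under linear warmup followed by step decay."""
    if iteration < warmup_iters:
        # Linear ramp of the multiplier from warmup_factor up to 1.
        alpha = iteration / warmup_iters
        factor = warmup_factor * (1 - alpha) + alpha
    else:
        factor = 1.0
    # Multiply by gamma once for every milestone already passed.
    return base_lr * factor * gamma ** bisect_right(milestones, iteration)


if __name__ == "__main__":
    # Hypothetical settings, for illustration only.
    for it in (0, 500, 1000, 16000, 22000, 23999):
        print(it, warmup_multistep_lr(it, base_lr=0.01, milestones=(16000, 22000)))
```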
Name | train time (s/iter) | inference time (s/im) | train mem (GB) | PA (%) | mean PA (%) | mean IoU (%) | FW IoU (%) | download link |
---|---|---|---|---|---|---|---|---|
R50 | 0.49 | 0.12 | 27.12 | 94.19 | 75.31 | 66.64 | 89.54 | model \| ga7k |
R101 | 0.65 | 0.16 | 28.81 | 94.29 | 76.08 | 67.57 | 89.69 | model \| xnvs |
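The accuracy columns are standard semantic-segmentation metrics computed from a confusion matrix over the evaluated classes; they presumably correspond to the pACC, mACC, mIoU and fwIoU values reported by detectron2's `SemSegEvaluator`. The NumPy sketch below shows one common way to compute them, skipping classes that never appear. The function names are hypothetical, and the numbers in the table come from the repo's own evaluation, not from this code.

```python
import numpy as np


def update_confusion(conf, gt, pred, num_classes, ignore_label=255):
    """Add one image to `conf`, where conf[i, j] counts pixels with ground truth i predicted as j."""
    mask = gt != ignore_label
    idx = num_classes * gt[mask].astype(np.int64) + pred[mask]
    conf += np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return conf


def segmentation_metrics(conf):
    """PA, mean PA, mean IoU and FW IoU (all in %) from a confusion matrix."""
    conf = np.asarray(conf, dtype=np.float64)
    tp = np.diag(conf)
    gt_per_class = conf.sum(axis=1)     # pixels of each ground-truth class
    pred_per_class = conf.sum(axis=0)   # pixels predicted as each class
    union = gt_per_class + pred_per_class - tp

    valid_acc = gt_per_class > 0        # classes absent from the ground truth are skipped
    valid_iou = union > 0
    acc = np.zeros_like(tp)
    iou = np.zeros_like(tp)
    acc[valid_acc] = tp[valid_acc] / gt_per_class[valid_acc]
    iou[valid_iou] = tp[valid_iou] / union[valid_iou]

    freq = gt_per_class / conf.sum()    # class frequency used to weight FW IoU
    return {
        "PA": 100 * tp.sum() / conf.sum(),
        "mean PA": 100 * acc[valid_acc].mean(),
        "mean IoU": 100 * iou[valid_iou].mean(),
        "FW IoU": 100 * (freq * iou).sum(),
    }
```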