Commit
* rebase to master

* little modification

* augmentation implemented with torch

* use cuda11.3+torch11, torchrun

* refactor

* support other value of ignore label

* discard distributed
CoinCheung authored Jun 25, 2022
1 parent f9231b7 commit a54e37a
Showing 14 changed files with 121 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -114,7 +114,7 @@ tensorrt/build/*
datasets/coco/train.txt
datasets/coco/val.txt
pretrained/*
dist_train.sh
run.sh
openvino/build/*
openvino/output*
*.onnx
36 changes: 5 additions & 31 deletions README.md
@@ -46,11 +46,11 @@ Triton Inference Server (TIS) provides a service solution for deployment. You can
My platform is like this:

* ubuntu 18.04
* nvidia Tesla T4 gpu, driver 450.51.05
* cuda 10.2
* cudnn 7
* nvidia Tesla T4 gpu, driver 450.80.02
* cuda 10.2/11.3
* cudnn 8
* miniconda python 3.8.8
* pytorch 1.8.1
* pytorch 1.11.0
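
Based on the updated platform above, a matching environment might be created like this (a hedged sketch; the conda channel and exact pins are my assumptions, not from the repo):

```bash
conda create -n bisenet python=3.8
conda activate bisenet
conda install pytorch=1.11.0 torchvision cudatoolkit=11.3 -c pytorch
```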


## get start
@@ -114,33 +114,7 @@ Then you need to change the field of `im_root` and `train/val_im_anns` in the co

## train

I used the following command to train the models:

```bash
# bisenetv1 cityscapes
export CUDA_VISIBLE_DEVICES=0,1
cfg_file=configs/bisenetv1_city.py
NGPUS=2
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file

# bisenetv2 cityscapes
export CUDA_VISIBLE_DEVICES=0,1
cfg_file=configs/bisenetv2_city.py
NGPUS=2
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file

# bisenetv1 cocostuff
export CUDA_VISIBLE_DEVICES=0,1,2,3
cfg_file=configs/bisenetv1_coco.py
NGPUS=4
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file

# bisenetv2 cocostuff
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
cfg_file=configs/bisenetv2_coco.py
NGPUS=8
python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file
```
The training commands I used to train the models can be found [here](./dist_train.sh); a representative example follows.
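
For a quick start, here is one of those commands (two gpus, bisenetv1 on cityscapes), excerpted from that script:

```bash
export CUDA_VISIBLE_DEVICES=0,1
cfg_file=configs/bisenetv1_city.py
torchrun --nproc_per_node=2 tools/train_amp.py --config $cfg_file
```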

Note:
1. Though `bisenetv2` has fewer flops, it requires many more training iterations, so the training time of `bisenetv1` is shorter.
32 changes: 32 additions & 0 deletions dist_train.sh
@@ -0,0 +1,32 @@

## NOTE: replace torchrun with python -m torch.distributed.launch if you use an
## older version of pytorch. I suggest you use the same version as I do, since
## I have not tested compatibility with older versions after this update.
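## e.g. a hedged equivalent for older pytorch, mirroring run.sh in this commit:
## python -m torch.distributed.launch --use_env --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file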


## bisenetv1 cityscapes
export CUDA_VISIBLE_DEVICES=0,1
cfg_file=configs/bisenetv1_city.py
NGPUS=2
torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file


## bisenetv2 cityscapes
export CUDA_VISIBLE_DEVICES=0,1
cfg_file=configs/bisenetv2_city.py
NGPUS=2
torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file


## bisenetv1 cocostuff
export CUDA_VISIBLE_DEVICES=0,1,2,3
cfg_file=configs/bisenetv1_coco.py
NGPUS=4
torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file


## bisenetv2 cocostuff
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
cfg_file=configs/bisenetv2_coco.py
NGPUS=8
torchrun --nproc_per_node=$NGPUS tools/train_amp.py --config $cfg_file
4 changes: 3 additions & 1 deletion lib/base_dataset.py
@@ -24,6 +24,7 @@ def __init__(self, dataroot, annpath, trans_func=None, mode='train'):
self.mode = mode
self.trans_func = trans_func

self.lb_ignore = -100
self.lb_map = None

with open(annpath, 'r') as fr:
@@ -50,7 +51,8 @@ def __getitem__(self, idx):
return img.detach(), label.unsqueeze(0).detach()

def get_image(self, impth, lbpth):
img, label = cv2.imread(impth)[:, :, ::-1], cv2.imread(lbpth, 0)
img = cv2.imread(impth)[:, :, ::-1].copy()
label = cv2.imread(lbpth, 0)
return img, label

def __len__(self):
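The `.copy()` added in `get_image` matters because `[:, :, ::-1]` yields a negative-stride view, which `torch.from_numpy` rejects; a minimal sketch of the failure mode (file path assumed):

```python
import cv2
import torch

img = cv2.imread('example.png')[:, :, ::-1]  # BGR -> RGB, a negative-stride view
# torch.from_numpy(img)   # would raise: negative strides are not supported
img = img.copy()          # materialize a contiguous buffer
tensor = torch.from_numpy(img)  # now succeeds
```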
4 changes: 2 additions & 2 deletions lib/coco.py
@@ -48,9 +48,9 @@
class CocoStuff(BaseDataset):

def __init__(self, dataroot, annpath, trans_func=None, mode='train'):
super(CocoStuff, self).__init__(dataroot, annpath, trans_func, mode)
super(CocoStuff, self).__init__(
dataroot, annpath, trans_func, mode)
self.n_cats = 171 # 91 stuff, 91 thing, 11 of thing have no annos
self.lb_ignore = 255

## label mapping, remove non-existing labels
missing = [11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90]
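For context, the mapping mentioned in that comment compresses the raw label ids into 171 contiguous training ids; a hedged sketch of the idea (not the verbatim repo code):

```python
import numpy as np

missing = [11, 25, 28, 29, 44, 65, 67, 68, 70, 82, 90]
remain = [i for i in range(182) if i not in missing]  # 171 ids survive

lb_map = np.arange(256)                  # identity for values left untouched
for new_id, old_id in enumerate(remain):
    lb_map[old_id] = new_id              # raw id -> contiguous training id
```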
31 changes: 4 additions & 27 deletions lib/get_dataloader.py
@@ -10,48 +10,25 @@



class TransformationTrain(object):

def __init__(self, scales, cropsize):
self.trans_func = T.Compose([
T.RandomResizedCrop(scales, cropsize),
T.RandomHorizontalFlip(),
T.ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4
),
])

def __call__(self, im_lb):
im_lb = self.trans_func(im_lb)
return im_lb


class TransformationVal(object):

def __call__(self, im_lb):
im, lb = im_lb['im'], im_lb['lb']
return dict(im=im, lb=lb)


def get_data_loader(cfg, mode='train', distributed=True):
def get_data_loader(cfg, mode='train'):
if mode == 'train':
trans_func = TransformationTrain(cfg.scales, cfg.cropsize)
trans_func = T.TransformationTrain(cfg.scales, cfg.cropsize)
batchsize = cfg.ims_per_gpu
annpath = cfg.train_im_anns
shuffle = True
drop_last = True
elif mode == 'val':
trans_func = TransformationVal()
trans_func = T.TransformationVal()
batchsize = cfg.eval_ims_per_gpu
annpath = cfg.val_im_anns
shuffle = False
drop_last = False

ds = eval(cfg.dataset)(cfg.im_root, annpath, trans_func=trans_func, mode=mode)

if distributed:
if dist.is_initialized():
assert dist.is_available(), "dist should be initialized"
if mode == 'train':
assert not cfg.max_iter is None
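Past the truncated hunk, the loader presumably picks a sampler from this runtime check rather than from the removed `distributed` flag; a minimal sketch of that pattern (variable names assumed):

```python
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

if dist.is_initialized():
    sampler = DistributedSampler(ds, shuffle=shuffle)
    dl = DataLoader(ds, batch_size=batchsize, sampler=sampler, drop_last=drop_last)
else:
    dl = DataLoader(ds, batch_size=batchsize, shuffle=shuffle, drop_last=drop_last)
```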
1 change: 1 addition & 0 deletions lib/logger.py
@@ -19,6 +19,7 @@ def setup_logger(name, logpth):
try:
logging.basicConfig(level=log_level, format=FORMAT, filename=logfile, force=True)
except Exception:
for hl in logging.root.handlers: logging.root.removeHandler(hl)  # manual reset for python < 3.8, where force= is unavailable
logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
logging.root.addHandler(logging.StreamHandler())

18 changes: 9 additions & 9 deletions lib/ohem_ce_loss.py
@@ -10,30 +10,30 @@
# import ohem_cpp
# class OhemCELoss(nn.Module):
#
# def __init__(self, thresh, ignore_lb=255):
# def __init__(self, thresh, lb_ignore=255):
# super(OhemCELoss, self).__init__()
# self.score_thresh = thresh
# self.ignore_lb = ignore_lb
# self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='mean')
# self.lb_ignore = lb_ignore
# self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='mean')
#
# def forward(self, logits, labels):
# n_min = labels[labels != self.ignore_lb].numel() // 16
# n_min = labels[labels != self.lb_ignore].numel() // 16
# labels = ohem_cpp.score_ohem_label(
# logits, labels, self.ignore_lb, self.score_thresh, n_min).detach()
# logits, labels, self.lb_ignore, self.score_thresh, n_min).detach()
# loss = self.criteria(logits, labels)
# return loss


class OhemCELoss(nn.Module):

def __init__(self, thresh, ignore_lb=255):
def __init__(self, thresh, lb_ignore=255):
super(OhemCELoss, self).__init__()
self.thresh = -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda()
self.ignore_lb = ignore_lb
self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
self.lb_ignore = lb_ignore
self.criteria = nn.CrossEntropyLoss(ignore_index=lb_ignore, reduction='none')

def forward(self, logits, labels):
n_min = labels[labels != self.ignore_lb].numel() // 16
n_min = labels[labels != self.lb_ignore].numel() // 16
loss = self.criteria(logits, labels).view(-1)
loss_hard = loss[loss > self.thresh]
if loss_hard.numel() < n_min:
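A hedged usage sketch of the renamed loss (shapes illustrative; the class moves its threshold to GPU in `__init__`, so CUDA is assumed):

```python
import torch
from lib.ohem_ce_loss import OhemCELoss

criteria = OhemCELoss(thresh=0.7, lb_ignore=255)
logits = torch.randn(4, 19, 128, 128).cuda()         # N x C x H x W scores
labels = torch.randint(0, 19, (4, 128, 128)).cuda()  # N x H x W class ids
loss = criteria(logits, labels)                      # mean over the hardest pixels
```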
24 changes: 24 additions & 0 deletions lib/transform_cv2.py
@@ -151,6 +151,30 @@ def __call__(self, im_lb):
return im_lb


class TransformationTrain(object):

def __init__(self, scales, cropsize):
self.trans_func = Compose([
RandomResizedCrop(scales, cropsize),
RandomHorizontalFlip(),
ColorJitter(
brightness=0.4,
contrast=0.4,
saturation=0.4
),
])

def __call__(self, im_lb):
im_lb = self.trans_func(im_lb)
return im_lb


class TransformationVal(object):

def __call__(self, im_lb):
im, lb = im_lb['im'], im_lb['lb']
return dict(im=im, lb=lb)



if __name__ == '__main__':
2 changes: 1 addition & 1 deletion openvino/main.cpp
@@ -9,7 +9,7 @@
#include <inference_engine.hpp>


std::string mdpth("../output_v2/model_v2.xml");
std::string mdpth("../output_v2/model_v2_city.xml");
std::string device("CPU"); // GNA does not support argmax, my cpu does not has integrated gpu
std::string impth("../../example.png");
std::string savepth("./res.jpg");
14 changes: 14 additions & 0 deletions run.sh
@@ -0,0 +1,14 @@


# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7  # 8-gpu example; the line below takes effect
export CUDA_VISIBLE_DEVICES=2,3
PORT=12333
NGPUS=2
cfg=configs/bisenetv1_city.py
# cfg=configs/bisenetv2_city.py
# cfg=configs/bisenetv1_coco.py
# cfg=configs/bisenetv2_coco.py

torchrun --nproc_per_node=$NGPUS --master_port $PORT tools/train_amp.py --config $cfg
# python -m torch.distributed.launch --use_env --nproc_per_node=$NGPUS --master_port $PORT tools/train_amp.py --config $cfg

4 changes: 4 additions & 0 deletions tools/demo.py
@@ -12,6 +12,10 @@
from lib.models import model_factory
from configs import set_cfg_from_file


# uncomment the following line if you want to reduce cpu usage, see issue #231
# torch.set_num_threads(4)

torch.set_grad_enabled(False)
np.random.seed(123)

20 changes: 10 additions & 10 deletions tools/evaluate.py
@@ -249,17 +249,17 @@ def evaluate(cfg, weight_pth):
net.load_state_dict(torch.load(weight_pth, map_location='cpu'))
net.cuda()

is_dist = dist.is_initialized()
if is_dist:
local_rank = dist.get_rank()
net = nn.parallel.DistributedDataParallel(
net,
device_ids=[local_rank, ],
output_device=local_rank
)
# is_dist = dist.is_initialized()
# if is_dist:
# local_rank = dist.get_rank()
# net = nn.parallel.DistributedDataParallel(
# net,
# device_ids=[local_rank, ],
# output_device=local_rank
# )

## evaluator
heads, mious = eval_model(cfg, net.module)
heads, mious = eval_model(cfg, net)
logger.info(tabulate([mious, ], headers=heads, tablefmt='orgtbl'))


@@ -284,7 +284,7 @@ def main():
init_method='tcp://127.0.0.1:{}'.format(args.port),
world_size=torch.cuda.device_count(),
rank=args.local_rank
)
)
if not osp.exists(cfg.respth): os.makedirs(cfg.respth)
setup_logger('{}-eval'.format(cfg.model_type), cfg.respth)
evaluate(cfg, args.weight_pth)
27 changes: 11 additions & 16 deletions tools/train_amp.py
@@ -42,8 +42,6 @@

def parse_args():
parse = argparse.ArgumentParser()
parse.add_argument('--local_rank', dest='local_rank', type=int, default=-1,)
parse.add_argument('--port', dest='port', type=int, default=44554,)
parse.add_argument('--config', dest='config', type=str,
default='configs/bisenetv2.py',)
parse.add_argument('--finetune-from', type=str, default=None,)
@@ -54,7 +52,7 @@ def parse_args():



def set_model():
def set_model(lb_ignore=255):
logger = logging.getLogger()
net = model_factory[cfg.model_type](cfg.n_cats)
if not args.finetune_from is None:
@@ -63,8 +61,9 @@ def set_model():
if cfg.use_sync_bn: net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
net.cuda()
net.train()
criteria_pre = OhemCELoss(0.7)
criteria_aux = [OhemCELoss(0.7) for _ in range(cfg.num_aux_heads)]
criteria_pre = OhemCELoss(0.7, lb_ignore)
criteria_aux = [OhemCELoss(0.7, lb_ignore)
for _ in range(cfg.num_aux_heads)]
return net, criteria_pre, criteria_aux


@@ -100,7 +99,7 @@ def set_optimizer(model):


def set_model_dist(net):
local_rank = dist.get_rank()
local_rank = int(os.environ['LOCAL_RANK'])
net = nn.parallel.DistributedDataParallel(
net,
device_ids=[local_rank, ],
Expand All @@ -122,13 +121,12 @@ def set_meters():

def train():
logger = logging.getLogger()
is_dist = dist.is_initialized()

## dataset
dl = get_data_loader(cfg, mode='train', distributed=is_dist)
dl = get_data_loader(cfg, mode='train')

## model
net, criteria_pre, criteria_aux = set_model()
net, criteria_pre, criteria_aux = set_model(dl.dataset.lb_ignore)

## optimizer
optim = set_optimizer(net)
@@ -194,13 +192,10 @@ def train():


def main():
torch.cuda.set_device(args.local_rank)
dist.init_process_group(
backend='nccl',
init_method='tcp://127.0.0.1:{}'.format(args.port),
world_size=torch.cuda.device_count(),
rank=args.local_rank
)
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')

if not osp.exists(cfg.respth): os.makedirs(cfg.respth)
setup_logger(f'{cfg.model_type}-{cfg.dataset.lower()}-train', cfg.respth)
train()
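For reference, torchrun injects the rendezvous settings as environment variables, which is why the explicit `init_method`, `world_size`, and `rank` arguments could be dropped; a minimal sketch:

```python
import os
import torch.distributed as dist

# each process launched by torchrun sees, among others:
local_rank = int(os.environ['LOCAL_RANK'])   # gpu index on this node
rank = int(os.environ['RANK'])               # global process rank
world_size = int(os.environ['WORLD_SIZE'])   # total number of processes

# MASTER_ADDR / MASTER_PORT are also set, so this is enough:
dist.init_process_group(backend='nccl')
```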
