Commit 9ff609a: load pretrain

CoinCheung committed Jun 11, 2021
1 parent 90de8f9 commit 9ff609a
Showing 14 changed files with 110 additions and 13 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -111,4 +111,8 @@ preprocess_data.py
 res/
 adj.md
 tensorrt/build/*
+datasets/coco/train.txt
+datasets/coco/val.txt
+pretrained/*
+lib/coco.py

4 changes: 2 additions & 2 deletions configs/bisenetv2.py
@@ -3,7 +3,7 @@
 cfg = dict(
     model_type='bisenetv2',
     num_aux_heads=4,
-    lr_start = 5e-2,
+    lr_start = 1 * 5e-3,
     weight_decay=5e-4,
     warmup_iters = 1000,
     max_iter = 150000,
@@ -14,6 +14,6 @@
     cropsize=[512, 1024],
     ims_per_gpu=8,
     use_fp16=True,
-    use_sync_bn=False,
+    use_sync_bn=True,
     respth='./res',
 )
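Note: this commit both lowers the initial learning rate (5e-2 to 5e-3) and turns on synchronized batch norm for the 2-GPU run configured in dist_train.sh. How train_amp.py applies the flag is not shown in this diff; the usual PyTorch mechanism behind such a use_sync_bn switch is a sketch like:

    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    # Recursively replaces every BatchNorm*d with SyncBatchNorm, so batch
    # statistics are reduced across all processes during training.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    print(type(model[1]).__name__)  # SyncBatchNorm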
2 changes: 1 addition & 1 deletion datasets/cityscapes/gtFine
2 changes: 1 addition & 1 deletion datasets/cityscapes/leftImg8bit
1 change: 1 addition & 0 deletions datasets/coco/images/train2017
1 change: 1 addition & 0 deletions datasets/coco/images/val2017
1 change: 1 addition & 0 deletions datasets/coco/labels/train2017
1 change: 1 addition & 0 deletions datasets/coco/labels/val2017
6 changes: 6 additions & 0 deletions dist_train.sh
@@ -0,0 +1,6 @@
+
+export CUDA_VISIBLE_DEVICES=6,7
+PORT=52330
+NGPUS=2
+
+python -m torch.distributed.launch --nproc_per_node=$NGPUS tools/train_amp.py --model bisenetv2 --port $PORT
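Note: torch.distributed.launch starts NGPUS copies of tools/train_amp.py, injecting a --local_rank argument into each process and exporting the rendezvous environment variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE). A minimal sketch of the per-process setup a launched script performs; train_amp.py's actual init code is not part of this diff:

    import argparse
    import torch
    import torch.distributed as dist

    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=0)  # injected by the launcher
    args = parser.parse_args()

    # Bind this process to its GPU and join the process group via the
    # environment variables set by torch.distributed.launch.
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://')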
5 changes: 4 additions & 1 deletion lib/logger.py
@@ -16,7 +16,10 @@ def setup_logger(name, logpth):
     log_level = logging.INFO
     if dist.is_initialized() and dist.get_rank() != 0:
         log_level = logging.WARNING
-    logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
+    try:
+        logging.basicConfig(level=log_level, format=FORMAT, filename=logfile, force=True)
+    except Exception:
+        logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
     logging.root.addHandler(logging.StreamHandler())


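Note: force=True (added to logging.basicConfig in Python 3.8) removes any handlers already attached to the root logger before reconfiguring, so a later setup_logger call can re-point the log file; older interpreters reject the unknown keyword, which the except branch absorbs by falling back to the plain call. A standalone illustration of what force changes:

    import logging

    logging.basicConfig(filename='first.log', level=logging.INFO)
    # Without force=True this second call would be a no-op, because the
    # root logger already has a handler; with it (Python >= 3.8), logging
    # is torn down and re-pointed at second.log.
    logging.basicConfig(filename='second.log', level=logging.INFO, force=True)
    logging.info('recorded in second.log, not first.log')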
29 changes: 29 additions & 0 deletions lib/models/bisenetv2.py
@@ -41,6 +41,7 @@ def init_weight(self):
         nn.init.xavier_normal_(self.proj.weight, gain=1.)


+
 class DetailBranch(nn.Module):

     def __init__(self):
@@ -324,6 +325,7 @@ def __init__(self, n_classes, output_aux=True):
         self.aux5_4 = SegmentHead(128, 128, n_classes, up_factor=32)

         self.init_weights()
+        self.load_pretrain()

     def forward(self, x):
         size = x.size()[2:]
@@ -353,6 +355,33 @@ def init_weights(self):
                     nn.init.ones_(module.weight)
                 nn.init.zeros_(module.bias)

+    def load_pretrain(self):
+        state = torch.load('pretrained/bisenetv2_pretrain.pth', map_location='cpu')
+        state = {k: v for k, v in state.items() if k not in ('fc', 'head', 'dense_head')}
+        for name, child in self.named_children():
+            if name in state.keys():
+                child.load_state_dict(state[name])
+
+
+    def get_params(self):
+        wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
+        for name, param in self.named_parameters():
+            if 'head' in name or 'aux' in name:
+                if param.dim() == 1:
+                    lr_mul_nowd_params.append(param)
+                elif param.dim() == 4:
+                    lr_mul_wd_params.append(param)
+                else:
+                    print(name)
+            else:
+                if param.dim() == 1:
+                    nowd_params.append(param)
+                elif param.dim() == 4:
+                    wd_params.append(param)
+                else:
+                    print(name)
+        return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
+

 if __name__ == "__main__":
     # x = torch.randn(16, 3, 1024, 2048)
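Note: load_pretrain expects pretrained/bisenetv2_pretrain.pth to hold a dict keyed by the model's top-level child names, each value being that submodule's state_dict, with classifier entries such as fc/head/dense_head filtered out before loading. A hedged sketch of producing a file in that layout; the toy module names below are illustrative, not from this commit:

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.detail = nn.Conv2d(3, 64, 3)  # loaded by child name
            self.head = nn.Linear(64, 10)      # dropped by the filter above

    m = Toy()
    # One sub-state_dict per named child: the layout load_pretrain iterates over.
    state = {name: child.state_dict() for name, child in m.named_children()}
    torch.save(state, 'pretrained/bisenetv2_pretrain.pth')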
9 changes: 8 additions & 1 deletion lib/sampler.py
@@ -24,7 +24,7 @@ class RepeatedDistSampler(Sampler):
         shuffle (optional): If true (default), sampler will shuffle the indices
     """

-    def __init__(self, dataset, num_imgs, num_replicas=None, rank=None, shuffle=True):
+    def __init__(self, dataset, num_imgs, num_replicas=None, rank=None, shuffle=True, ba=False):
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
@@ -40,6 +40,7 @@ def __init__(self, dataset, num_imgs, num_replicas=None, rank=None, shuffle=True
         self.total_size = self.num_imgs_rank * self.num_replicas
         self.num_imgs = num_imgs
         self.shuffle = shuffle
+        self.ba = ba


     def __iter__(self):
@@ -58,6 +59,12 @@ def __iter__(self):
         indices = indices[:self.total_size]
         assert len(indices) == self.total_size

+        if self.ba:
+            n_rep = max(4, self.num_replicas)
+            len_ind = len(indices) // n_rep + 1
+            indices = indices[:len_ind]
+            indices = [ind for ind in indices for _ in range(n_rep)]
+
         # subsample
         indices = indices[self.rank:self.total_size:self.num_replicas]
         assert len(indices) == self.num_imgs_rank
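Note: the new ba flag implements batch-augmentation-style repetition: the shuffled index list is truncated to roughly 1/n_rep of its length and each surviving index is duplicated n_rep times, so every sampled image appears in several augmented copies per pass. A standalone trace of the index transformation (ba=True and num_replicas=2 assumed):

    indices = list(range(10))      # stand-in for the shuffled dataset indices
    num_replicas = 2
    n_rep = max(4, num_replicas)   # each kept image repeated 4 times here
    len_ind = len(indices) // n_rep + 1
    indices = indices[:len_ind]    # keep only the first len_ind images
    indices = [ind for ind in indices for _ in range(n_rep)]
    print(indices)                 # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]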
42 changes: 42 additions & 0 deletions tools/gen_coco_annos.py
@@ -0,0 +1,42 @@
+
+import os
+import os.path as osp
+
+
+def gen_coco():
+    '''
+    root_path:
+        |- images
+            |- train2017
+            |- val2017
+        |- labels
+            |- train2017
+            |- val2017
+    '''
+    root_path = '/datasets/coco'
+    save_path = './datasets/coco/'
+    for mode in ('train', 'val'):
+        im_root = osp.join(root_path, f'images/{mode}2017')
+        lb_root = osp.join(root_path, f'labels/{mode}2017')
+
+        ims = os.listdir(im_root)
+        lbs = os.listdir(lb_root)
+
+        print(len(ims))
+        print(len(lbs))
+
+        im_names = [el.replace('.jpg', '') for el in ims]
+        lb_names = [el.replace('.png', '') for el in lbs]
+        common_names = list(set(im_names) & set(lb_names))
+
+        lines = [
+            f'images/{mode}2017/{name}.jpg,labels/{mode}2017/{name}.png'
+            for name in common_names
+        ]
+
+        with open(f'{save_path}/{mode}.txt', 'w') as fw:
+            fw.write('\n'.join(lines))
+
+
+
+gen_coco()
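Note: each line of the generated train.txt/val.txt pairs a relative image path with its label-mask path, comma-separated, and only names present in both folders survive the set intersection. A hedged sketch of how a dataset loader might consume the file; the parsing shown is an assumption, not code from this commit:

    # Hypothetical consumer of datasets/coco/train.txt.
    with open('datasets/coco/train.txt') as fr:
        pairs = [line.strip().split(',') for line in fr if line.strip()]
    for im_rel, lb_rel in pairs[:3]:
        print(im_rel, '->', lb_rel)  # images/train2017/xxx.jpg -> labels/train2017/xxx.png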
16 changes: 9 additions & 7 deletions tools/train_amp.py
@@ -29,11 +29,11 @@


 ## fix all random seeds
-torch.manual_seed(123)
-torch.cuda.manual_seed(123)
-np.random.seed(123)
-random.seed(123)
-torch.backends.cudnn.deterministic = True
+# torch.manual_seed(123)
+# torch.cuda.manual_seed(123)
+# np.random.seed(123)
+# random.seed(123)
+# torch.backends.cudnn.deterministic = True
 # torch.backends.cudnn.benchmark = True
 # torch.multiprocessing.set_sharing_strategy('file_system')

@@ -68,11 +68,13 @@ def set_model():
 def set_optimizer(model):
     if hasattr(model, 'get_params'):
         wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
+        # wd_val = cfg.weight_decay
+        wd_val = 0
         params_list = [
             {'params': wd_params, },
-            {'params': nowd_params, 'weight_decay': 0},
+            {'params': nowd_params, 'weight_decay': wd_val},
             {'params': lr_mul_wd_params, 'lr': cfg.lr_start * 10},
-            {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr': cfg.lr_start * 10},
+            {'params': lr_mul_nowd_params, 'weight_decay': wd_val, 'lr': cfg.lr_start * 10},
         ]
     else:
         wd_params, non_wd_params = [], []
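Note: the four groups from get_params map onto optimizer parameter groups: 4-D conv weights keep the optimizer's default weight decay, 1-D parameters (batch-norm scales and biases) get wd_val, and head/aux parameters train at 10x the base learning rate. Setting wd_val = 0 keeps decay off the norm and bias parameters, a common choice since decaying them regularizes little. A self-contained sketch of the same grouping on a toy model; the SGD settings are assumptions, not shown in this diff:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    lr_start, weight_decay, wd_val = 5e-3, 5e-4, 0

    wd_params = [p for p in model.parameters() if p.dim() == 4]    # conv kernels
    nowd_params = [p for p in model.parameters() if p.dim() == 1]  # BN/bias params
    optim = torch.optim.SGD(
        [{'params': wd_params},
         {'params': nowd_params, 'weight_decay': wd_val}],
        lr=lr_start, momentum=0.9, weight_decay=weight_decay,
    )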
