Skip to content

Commit

Permalink
Add training script and using config file format (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghang1989 authored Mar 12, 2021
1 parent 11eb547 commit 52eda71
Show file tree
Hide file tree
Showing 24 changed files with 1,083 additions and 6 deletions.
2 changes: 2 additions & 0 deletions configs/Base-ResNet50.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
MODEL:
NAME: 'resnet50'
File renamed without changes.
3 changes: 1 addition & 2 deletions resnest/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from .resnest import *
from .ablation import *
from .models import *
45 changes: 45 additions & 0 deletions resnest/torch/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
from fvcore.common.config import CfgNode as CN

# Root config node; every option below is attached to it with its default value.
_C = CN()

_C.SEED = 1

## data related
_C.DATA = CN()
_C.DATA.DATASET = 'ImageNet'
# Default dataset location under the user's home directory; this assumes the
# dataset was set up using the provided script.
_C.DATA.ROOT = os.path.expanduser('~/.encoding/data/ILSVRC2012')
_C.DATA.BASE_SIZE = None
_C.DATA.CROP_SIZE = 224
_C.DATA.LABEL_SMOOTHING = 0.0
_C.DATA.MIXUP = 0.0
_C.DATA.RAND_AUG = False

## model related
_C.MODEL = CN()
_C.MODEL.NAME = 'resnet50'
_C.MODEL.FINAL_DROP = False

## training params
_C.TRAINING = CN()
# (per-gpu batch size)
_C.TRAINING.BATCH_SIZE = 64
_C.TRAINING.TEST_BATCH_SIZE = 256
_C.TRAINING.LAST_GAMMA = False
_C.TRAINING.EPOCHS = 120
_C.TRAINING.START_EPOCHS = 0
_C.TRAINING.WORKERS = 4

## optimizer params
_C.OPTIMIZER = CN()
# (per-gpu lr)
_C.OPTIMIZER.LR = 0.025
_C.OPTIMIZER.LR_SCHEDULER = 'cos'
_C.OPTIMIZER.MOMENTUM = 0.9
_C.OPTIMIZER.WEIGHT_DECAY = 1e-4
_C.OPTIMIZER.DISABLE_BN_WD = False
_C.OPTIMIZER.WARMUP_EPOCHS = 0

def get_cfg() -> CN:
    """Return a clone of the module-level default configuration ``_C``.

    The defaults themselves are not handed out; callers receive the result
    of ``_C.clone()``.
    """
    defaults_copy = _C.clone()
    return defaults_copy
2 changes: 2 additions & 0 deletions resnest/torch/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .build import get_dataset, RESNEST_DATASETS_REGISTRY
from .imagenet import ImageNet
6 changes: 6 additions & 0 deletions resnest/torch/datasets/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from fvcore.common.registry import Registry

# Registry in which dataset classes are recorded under their names.
RESNEST_DATASETS_REGISTRY = Registry('RESNEST_DATASETS')


def get_dataset(dataset_name):
    """Look up ``dataset_name`` in ``RESNEST_DATASETS_REGISTRY``.

    Args:
        dataset_name: Name under which the dataset was registered.

    Returns:
        Whatever object the registry holds for that name.
    """
    registered = RESNEST_DATASETS_REGISTRY.get(dataset_name)
    return registered
25 changes: 25 additions & 0 deletions resnest/torch/datasets/imagenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: [email protected]
## Copyright (c) 2018
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import warnings
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

from .build import RESNEST_DATASETS_REGISTRY

@RESNEST_DATASETS_REGISTRY.register()
class ImageNet(datasets.ImageFolder):
    """Image-folder dataset rooted at the ``train`` or ``val`` sub-directory.

    Args:
        root: Directory that contains the ``train`` and ``val`` sub-folders.
        transform: Forwarded to ``ImageFolder`` as its ``transform``.
        target_transform: Forwarded to ``ImageFolder`` as its
            ``target_transform``.
        train: Any truthy value selects the ``train`` sub-folder; a falsy
            value selects ``val``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, root=os.path.expanduser('~/.encoding/data/ILSVRC2012'), transform=None,
                 target_transform=None, train=True, **kwargs):
        # Use truthiness instead of ``== True`` so that every truthy flag
        # (not only values equal to ``True``) selects the training split.
        split = 'train' if train else 'val'
        super().__init__(os.path.join(root, split), transform, target_transform)
63 changes: 63 additions & 0 deletions resnest/torch/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
from resnest.torch.utils import MixUpWrapper

__all__ = ['LabelSmoothing', 'NLLMultiLabelSmooth', 'get_criterion']

def get_criterion(cfg, train_loader, gpu):
    """Choose the loss module according to ``cfg.DATA``.

    Selection order:
      1. ``cfg.DATA.MIXUP > 0``: the loader is wrapped in ``MixUpWrapper``
         and the loss is ``NLLMultiLabelSmooth``.
      2. ``cfg.DATA.LABEL_SMOOTHING > 0.0``: the loss is ``LabelSmoothing``.
      3. Otherwise: ``torch.nn.CrossEntropyLoss``.

    Args:
        cfg: Config object exposing ``DATA.MIXUP`` and ``DATA.LABEL_SMOOTHING``.
        train_loader: Training data loader; returned unchanged unless mixup
            is enabled.
        gpu: Forwarded to ``MixUpWrapper`` as its last argument.

    Returns:
        tuple: ``(criterion, train_loader)``.
    """
    data_cfg = cfg.DATA

    if data_cfg.MIXUP > 0:
        # NOTE(review): 1000 is passed positionally to MixUpWrapper; presumably
        # the number of classes -- confirm against MixUpWrapper's signature.
        mixed_loader = MixUpWrapper(data_cfg.MIXUP, 1000, train_loader, gpu)
        return NLLMultiLabelSmooth(data_cfg.LABEL_SMOOTHING), mixed_loader

    if data_cfg.LABEL_SMOOTHING > 0.0:
        return LabelSmoothing(data_cfg.LABEL_SMOOTHING), train_loader

    return torch.nn.CrossEntropyLoss(), train_loader

class LabelSmoothing(nn.Module):
    """Negative log-likelihood loss with label smoothing, computed from raw scores."""

    def __init__(self, smoothing=0.1):
        """Store the smoothing factor and its complement.

        :param smoothing: label smoothing factor; ``1.0 - smoothing`` is kept
            as ``confidence``
        """
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, x, target):
        """Return the mean smoothed loss over all rows of ``x``.

        :param x: scores; log-softmax is taken over the last dimension
        :param target: integer indices, one per row, used to gather the
            log-probability along the last dimension
        """
        log_probs = torch.nn.functional.log_softmax(x, dim=-1)

        # Log-probability assigned to the indexed class of each row.
        picked = log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        nll_loss = -picked

        # Mean negative log-probability across the last dimension.
        smooth_loss = -log_probs.mean(dim=-1)

        per_row = self.confidence * nll_loss + self.smoothing * smooth_loss
        return per_row.mean()

class NLLMultiLabelSmooth(nn.Module):
    """Smoothed NLL loss whose training-mode target is a per-class weight tensor.

    In evaluation mode the module falls back to
    ``torch.nn.functional.cross_entropy``.
    """

    def __init__(self, smoothing = 0.1):
        """Store ``smoothing`` and its complement ``confidence``."""
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, x, target):
        """Return the mean loss for scores ``x`` against ``target``.

        In training mode both inputs are cast to float and ``target`` is
        multiplied element-wise with the log-probabilities, so it must
        broadcast against ``x``. Outside training mode the inputs go straight
        to ``cross_entropy``.
        """
        if not self.training:
            return torch.nn.functional.cross_entropy(x, target)

        scores = x.float()
        weights = target.float()
        log_probs = torch.nn.functional.log_softmax(scores, dim = -1)

        # Target-weighted negative log-probability, summed over the last dim.
        nll_loss = (-log_probs * weights).sum(-1)
        smooth_loss = -log_probs.mean(dim=-1)

        per_row = self.confidence * nll_loss + self.smoothing * smooth_loss
        return per_row.mean()

2 changes: 2 additions & 0 deletions resnest/torch/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .resnest import *
from .ablation import *
File renamed without changes.
6 changes: 6 additions & 0 deletions resnest/torch/models/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from fvcore.common.registry import Registry

# Registry in which model constructors are recorded under their names.
RESNEST_MODELS_REGISTRY = Registry('RESNEST_MODELS')


def get_model(model_name):
    """Look up ``model_name`` in ``RESNEST_MODELS_REGISTRY``.

    Args:
        model_name: Name under which the model was registered.

    Returns:
        Whatever object the registry holds for that name.
    """
    registered = RESNEST_MODELS_REGISTRY.get(model_name)
    return registered
5 changes: 5 additions & 0 deletions resnest/torch/resnest.py → resnest/torch/models/resnest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .resnet import ResNet, Bottleneck

__all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269']
from .build import RESNEST_MODELS_REGISTRY

_url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth'

Expand All @@ -30,6 +31,7 @@ def short_hash(name):
name in _model_sha256.keys()
}

@RESNEST_MODELS_REGISTRY.register()
def resnest50(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3],
radix=2, groups=1, bottleneck_width=64,
Expand All @@ -40,6 +42,7 @@ def resnest50(pretrained=False, root='~/.encoding/models', **kwargs):
resnest_model_urls['resnest50'], progress=True, check_hash=True))
return model

@RESNEST_MODELS_REGISTRY.register()
def resnest101(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 4, 23, 3],
radix=2, groups=1, bottleneck_width=64,
Expand All @@ -50,6 +53,7 @@ def resnest101(pretrained=False, root='~/.encoding/models', **kwargs):
resnest_model_urls['resnest101'], progress=True, check_hash=True))
return model

@RESNEST_MODELS_REGISTRY.register()
def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 24, 36, 3],
radix=2, groups=1, bottleneck_width=64,
Expand All @@ -60,6 +64,7 @@ def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
resnest_model_urls['resnest200'], progress=True, check_hash=True))
return model

@RESNEST_MODELS_REGISTRY.register()
def resnest269(pretrained=False, root='~/.encoding/models', **kwargs):
model = ResNet(Bottleneck, [3, 30, 48, 8],
radix=2, groups=1, bottleneck_width=64,
Expand Down
55 changes: 54 additions & 1 deletion resnest/torch/resnet.py → resnest/torch/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,25 @@
import torch.nn as nn

from .splat import SplAtConv2d
from .build import RESNEST_MODELS_REGISTRY

__all__ = ['ResNet', 'Bottleneck']

# Download URL template; the two placeholders receive the model name and the
# shortened checksum, in that order.
_url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth'

# Maps model name -> checksum, built from (checksum, name) pairs.
# The pair list is empty here.
_model_sha256 = {
    model_name: digest
    for digest, model_name in [
    ]
}


def short_hash(name):
    """Return the first eight characters of the checksum recorded for ``name``.

    Raises:
        ValueError: If ``name`` has no entry in ``_model_sha256``.
    """
    if name in _model_sha256:
        return _model_sha256[name][:8]
    raise ValueError('Pretrained model for {name} is not available.'.format(name=name))


# Maps model name -> download URL for every model with a recorded checksum.
resnest_model_urls = {
    model_name: _url_format.format(model_name, short_hash(model_name))
    for model_name in _model_sha256
}

class DropBlock2D:
    """Stand-in class whose constructor always raises ``NotImplementedError``."""

    def __init__(self, *args, **kwargs):
        # Any attempt to instantiate fails, whatever arguments are given.
        raise NotImplementedError()
Expand Down Expand Up @@ -296,10 +312,47 @@ def forward(self, x):
x = self.layer4(x)

x = self.avgpool(x)
#x = x.view(x.size(0), -1)
x = torch.flatten(x, 1)
if self.drop:
x = self.drop(x)
x = self.fc(x)

return x

@RESNEST_MODELS_REGISTRY.register()
def resnet50(pretrained=False, root='~/.encoding/models', **kwargs):
    """Build a ResNet-50: ``Bottleneck`` blocks with stage depths [3, 4, 6, 3].

    Args:
        pretrained (bool): If True, download the state dict from the URL
            stored under ``'resnet50'`` in ``resnest_model_urls`` and load it
            into the model.
        root: Not referenced inside this function.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net

    # NOTE(review): the URL lookup raises KeyError when 'resnet50' has no
    # entry in resnest_model_urls -- confirm a checksum is registered for it.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls['resnet50'], progress=True, check_hash=True)
    net.load_state_dict(weights)
    return net


@RESNEST_MODELS_REGISTRY.register()
def resnet101(pretrained=False, root='~/.encoding/models', **kwargs):
    """Build a ResNet-101: ``Bottleneck`` blocks with stage depths [3, 4, 23, 3].

    Args:
        pretrained (bool): If True, download the state dict from the URL
            stored under ``'resnet101'`` in ``resnest_model_urls`` and load it
            into the model.
        root: Not referenced inside this function.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net

    # NOTE(review): the URL lookup raises KeyError when 'resnet101' has no
    # entry in resnest_model_urls -- confirm a checksum is registered for it.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls['resnet101'], progress=True, check_hash=True)
    net.load_state_dict(weights)
    return net


@RESNEST_MODELS_REGISTRY.register()
def resnet152(pretrained=False, root='~/.encoding/models', **kwargs):
    """Build a ResNet-152: ``Bottleneck`` blocks with stage depths [3, 8, 36, 3].

    Args:
        pretrained (bool): If True, download the state dict from the URL
            stored under ``'resnet152'`` in ``resnest_model_urls`` and load it
            into the model.
        root: Not referenced inside this function.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net

    # NOTE(review): the URL lookup raises KeyError when 'resnet152' has no
    # entry in resnest_model_urls -- confirm a checksum is registered for it.
    weights = torch.hub.load_state_dict_from_url(
        resnest_model_urls['resnet152'], progress=True, check_hash=True)
    net.load_state_dict(weights)
    return net
File renamed without changes.
1 change: 1 addition & 0 deletions resnest/torch/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .build import get_transform, RESNEST_TRANSFORMS_REGISTRY
Loading

0 comments on commit 52eda71

Please sign in to comment.