From 52eda712db5d957c29b3560969793f6f115e7d94 Mon Sep 17 00:00:00 2001 From: Hang Zhang Date: Fri, 12 Mar 2021 14:32:56 -0800 Subject: [PATCH] Add training script and using config file format (#139) --- configs/Base-ResNet50.yaml | 2 + resnest/{ => gluon}/transforms.py | 0 resnest/torch/__init__.py | 3 +- resnest/torch/config.py | 45 ++++ resnest/torch/datasets/__init__.py | 2 + resnest/torch/datasets/build.py | 6 + resnest/torch/datasets/imagenet.py | 25 +++ resnest/torch/loss.py | 63 ++++++ resnest/torch/models/__init__.py | 2 + resnest/torch/{ => models}/ablation.py | 0 resnest/torch/models/build.py | 6 + resnest/torch/{ => models}/resnest.py | 5 + resnest/torch/{ => models}/resnet.py | 55 ++++- resnest/torch/{ => models}/splat.py | 0 resnest/torch/transforms/__init__.py | 1 + resnest/torch/transforms/autoaug.py | 197 +++++++++++++++++ resnest/torch/transforms/build.py | 53 +++++ resnest/torch/transforms/transforms.py | 122 +++++++++++ resnest/torch/utils.py | 204 +++++++++++++++++ scripts/gluon/train.py | 2 +- scripts/gluon/verify.py | 2 +- scripts/torch/train.py | 291 +++++++++++++++++++++++++ setup.py | 1 + tests/test_radix_major.py | 2 +- 24 files changed, 1083 insertions(+), 6 deletions(-) create mode 100644 configs/Base-ResNet50.yaml rename resnest/{ => gluon}/transforms.py (100%) create mode 100644 resnest/torch/config.py create mode 100644 resnest/torch/datasets/__init__.py create mode 100644 resnest/torch/datasets/build.py create mode 100644 resnest/torch/datasets/imagenet.py create mode 100644 resnest/torch/loss.py create mode 100644 resnest/torch/models/__init__.py rename resnest/torch/{ => models}/ablation.py (100%) create mode 100644 resnest/torch/models/build.py rename resnest/torch/{ => models}/resnest.py (93%) rename resnest/torch/{ => models}/resnet.py (87%) rename resnest/torch/{ => models}/splat.py (100%) create mode 100644 resnest/torch/transforms/__init__.py create mode 100644 resnest/torch/transforms/autoaug.py create mode 100644 resnest/torch/transforms/build.py create mode 100644 resnest/torch/transforms/transforms.py create mode 100644 resnest/torch/utils.py create mode 100644 scripts/torch/train.py diff --git a/configs/Base-ResNet50.yaml b/configs/Base-ResNet50.yaml new file mode 100644 index 0000000..8985efd --- /dev/null +++ b/configs/Base-ResNet50.yaml @@ -0,0 +1,2 @@ +MODEL: + NAME: 'resnet50' diff --git a/resnest/transforms.py b/resnest/gluon/transforms.py similarity index 100% rename from resnest/transforms.py rename to resnest/gluon/transforms.py diff --git a/resnest/torch/__init__.py b/resnest/torch/__init__.py index 2acf216..aed4fa3 100644 --- a/resnest/torch/__init__.py +++ b/resnest/torch/__init__.py @@ -1,2 +1 @@ -from .resnest import * -from .ablation import * +from .models import * diff --git a/resnest/torch/config.py b/resnest/torch/config.py new file mode 100644 index 0000000..68ece7a --- /dev/null +++ b/resnest/torch/config.py @@ -0,0 +1,45 @@ +import os +from fvcore.common.config import CfgNode as CN + +_C = CN() + +_C.SEED = 1 + +## data related +_C.DATA = CN() +_C.DATA.DATASET = 'ImageNet' +# assuming you've set up the dataset using provided script +_C.DATA.ROOT = os.path.expanduser('~/.encoding/data/ILSVRC2012') +_C.DATA.BASE_SIZE = None +_C.DATA.CROP_SIZE = 224 +_C.DATA.LABEL_SMOOTHING = 0.0 +_C.DATA.MIXUP = 0.0 +_C.DATA.RAND_AUG = False + +## model related +_C.MODEL = CN() +_C.MODEL.NAME = 'resnet50' +_C.MODEL.FINAL_DROP = False + +## training params +_C.TRAINING = CN() +# (per-gpu batch size) +_C.TRAINING.BATCH_SIZE = 64 +_C.TRAINING.TEST_BATCH_SIZE = 256 +_C.TRAINING.LAST_GAMMA = False +_C.TRAINING.EPOCHS = 120 +_C.TRAINING.START_EPOCHS = 0 +_C.TRAINING.WORKERS = 4 + +## optimizer params +_C.OPTIMIZER = CN() +# (per-gpu lr) +_C.OPTIMIZER.LR = 0.025 +_C.OPTIMIZER.LR_SCHEDULER = 'cos' +_C.OPTIMIZER.MOMENTUM = 0.9 +_C.OPTIMIZER.WEIGHT_DECAY = 1e-4 +_C.OPTIMIZER.DISABLE_BN_WD = False +_C.OPTIMIZER.WARMUP_EPOCHS = 0 + +def get_cfg() -> CN: + return _C.clone() diff --git a/resnest/torch/datasets/__init__.py b/resnest/torch/datasets/__init__.py new file mode 100644 index 0000000..ec907ff --- /dev/null +++ b/resnest/torch/datasets/__init__.py @@ -0,0 +1,2 @@ +from .build import get_dataset, RESNEST_DATASETS_REGISTRY +from .imagenet import ImageNet diff --git a/resnest/torch/datasets/build.py b/resnest/torch/datasets/build.py new file mode 100644 index 0000000..f00936d --- /dev/null +++ b/resnest/torch/datasets/build.py @@ -0,0 +1,6 @@ +from fvcore.common.registry import Registry + +RESNEST_DATASETS_REGISTRY = Registry('RESNEST_DATASETS') + +def get_dataset(dataset_name): + return RESNEST_DATASETS_REGISTRY.get(dataset_name) diff --git a/resnest/torch/datasets/imagenet.py b/resnest/torch/datasets/imagenet.py new file mode 100644 index 0000000..bb42b36 --- /dev/null +++ b/resnest/torch/datasets/imagenet.py @@ -0,0 +1,25 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## Email: zhanghang0704@gmail.com +## Copyright (c) 2018 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import os +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +import warnings +warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) + +from .build import RESNEST_DATASETS_REGISTRY + +@RESNEST_DATASETS_REGISTRY.register() +class ImageNet(datasets.ImageFolder): + def __init__(self, root=os.path.expanduser('~/.encoding/data/ILSVRC2012'), transform=None, + target_transform=None, train=True, **kwargs): + split='train' if train == True else 'val' + root = os.path.join(root, split) + super().__init__(root, transform, target_transform) diff --git a/resnest/torch/loss.py b/resnest/torch/loss.py new file mode 100644 index 0000000..8c31655 --- /dev/null +++ b/resnest/torch/loss.py @@ -0,0 +1,63 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.autograd import Variable +from resnest.torch.utils import MixUpWrapper + +__all__ = ['LabelSmoothing', 'NLLMultiLabelSmooth', 'get_criterion'] + +def get_criterion(cfg, train_loader, gpu): + if cfg.DATA.MIXUP > 0: + train_loader = MixUpWrapper(cfg.DATA.MIXUP, 1000, train_loader, gpu) + criterion = NLLMultiLabelSmooth(cfg.DATA.LABEL_SMOOTHING) + elif cfg.DATA.LABEL_SMOOTHING > 0.0: + criterion = LabelSmoothing(cfg.DATA.LABEL_SMOOTHING) + else: + criterion = torch.nn.CrossEntropyLoss() + return criterion, train_loader + +class LabelSmoothing(nn.Module): + """ + NLL loss with label smoothing. + """ + def __init__(self, smoothing=0.1): + """ + Constructor for the LabelSmoothing module. + :param smoothing: label smoothing factor + """ + super(LabelSmoothing, self).__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + logprobs = torch.nn.functional.log_softmax(x, dim=-1) + + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1) + smooth_loss = -logprobs.mean(dim=-1) + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.mean() + +class NLLMultiLabelSmooth(nn.Module): + def __init__(self, smoothing = 0.1): + super(NLLMultiLabelSmooth, self).__init__() + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + if self.training: + x = x.float() + target = target.float() + logprobs = torch.nn.functional.log_softmax(x, dim = -1) + + nll_loss = -logprobs * target + nll_loss = nll_loss.sum(-1) + + smooth_loss = -logprobs.mean(dim=-1) + + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + + return loss.mean() + else: + return torch.nn.functional.cross_entropy(x, target) + diff --git a/resnest/torch/models/__init__.py b/resnest/torch/models/__init__.py new file mode 100644 index 0000000..2acf216 --- /dev/null +++ b/resnest/torch/models/__init__.py @@ -0,0 +1,2 @@ +from .resnest import * +from .ablation import * diff --git a/resnest/torch/ablation.py b/resnest/torch/models/ablation.py similarity index 100% rename from resnest/torch/ablation.py rename to resnest/torch/models/ablation.py diff --git a/resnest/torch/models/build.py b/resnest/torch/models/build.py new file mode 100644 index 0000000..26e7239 --- /dev/null +++ b/resnest/torch/models/build.py @@ -0,0 +1,6 @@ +from fvcore.common.registry import Registry + +RESNEST_MODELS_REGISTRY = Registry('RESNEST_MODELS') + +def get_model(model_name): + return RESNEST_MODELS_REGISTRY.get(model_name) diff --git a/resnest/torch/resnest.py b/resnest/torch/models/resnest.py similarity index 93% rename from resnest/torch/resnest.py rename to resnest/torch/models/resnest.py index 1a06b1e..ed3594f 100644 --- a/resnest/torch/resnest.py +++ b/resnest/torch/models/resnest.py @@ -11,6 +11,7 @@ from .resnet import ResNet, Bottleneck __all__ = ['resnest50', 'resnest101', 'resnest200', 'resnest269'] +from .build import RESNEST_MODELS_REGISTRY _url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' @@ -30,6 +31,7 @@ def short_hash(name): name in _model_sha256.keys() } +@RESNEST_MODELS_REGISTRY.register() def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): model = ResNet(Bottleneck, [3, 4, 6, 3], radix=2, groups=1, bottleneck_width=64, @@ -40,6 +42,7 @@ def resnest50(pretrained=False, root='~/.encoding/models', **kwargs): resnest_model_urls['resnest50'], progress=True, check_hash=True)) return model +@RESNEST_MODELS_REGISTRY.register() def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): model = ResNet(Bottleneck, [3, 4, 23, 3], radix=2, groups=1, bottleneck_width=64, @@ -50,6 +53,7 @@ def resnest101(pretrained=False, root='~/.encoding/models', **kwargs): resnest_model_urls['resnest101'], progress=True, check_hash=True)) return model +@RESNEST_MODELS_REGISTRY.register() def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): model = ResNet(Bottleneck, [3, 24, 36, 3], radix=2, groups=1, bottleneck_width=64, @@ -60,6 +64,7 @@ def resnest200(pretrained=False, root='~/.encoding/models', **kwargs): resnest_model_urls['resnest200'], progress=True, check_hash=True)) return model +@RESNEST_MODELS_REGISTRY.register() def resnest269(pretrained=False, root='~/.encoding/models', **kwargs): model = ResNet(Bottleneck, [3, 30, 48, 8], radix=2, groups=1, bottleneck_width=64, diff --git a/resnest/torch/resnet.py b/resnest/torch/models/resnet.py similarity index 87% rename from resnest/torch/resnet.py rename to resnest/torch/models/resnet.py index 26285de..609504f 100644 --- a/resnest/torch/resnet.py +++ b/resnest/torch/models/resnet.py @@ -11,9 +11,25 @@ import torch.nn as nn from .splat import SplAtConv2d +from .build import RESNEST_MODELS_REGISTRY __all__ = ['ResNet', 'Bottleneck'] +_url_format = 'https://s3.us-west-1.wasabisys.com/resnest/torch/{}-{}.pth' + +_model_sha256 = {name: checksum for checksum, name in [ + ]} + + +def short_hash(name): + if name not in _model_sha256: + raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) + return _model_sha256[name][:8] + +resnest_model_urls = {name: _url_format.format(name, short_hash(name)) for + name in _model_sha256.keys() +} + class DropBlock2D(object): def __init__(self, *args, **kwargs): raise NotImplementedError @@ -296,10 +312,47 @@ def forward(self, x): x = self.layer4(x) x = self.avgpool(x) - #x = x.view(x.size(0), -1) x = torch.flatten(x, 1) if self.drop: x = self.drop(x) x = self.fc(x) return x + +@RESNEST_MODELS_REGISTRY.register() +def resnet50(pretrained=False, root='~/.encoding/models', **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(torch.hub.load_state_dict_from_url( + resnest_model_urls['resnet50'], progress=True, check_hash=True)) + return model + + +@RESNEST_MODELS_REGISTRY.register() +def resnet101(pretrained=False, root='~/.encoding/models', **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(torch.hub.load_state_dict_from_url( + resnest_model_urls['resnet101'], progress=True, check_hash=True)) + return model + + +@RESNEST_MODELS_REGISTRY.register() +def resnet152(pretrained=False, root='~/.encoding/models', **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(torch.hub.load_state_dict_from_url( + resnest_model_urls['resnet152'], progress=True, check_hash=True)) + return model diff --git a/resnest/torch/splat.py b/resnest/torch/models/splat.py similarity index 100% rename from resnest/torch/splat.py rename to resnest/torch/models/splat.py diff --git a/resnest/torch/transforms/__init__.py b/resnest/torch/transforms/__init__.py new file mode 100644 index 0000000..f5e6f5c --- /dev/null +++ b/resnest/torch/transforms/__init__.py @@ -0,0 +1 @@ +from .build import get_transform, RESNEST_TRANSFORMS_REGISTRY diff --git a/resnest/torch/transforms/autoaug.py b/resnest/torch/transforms/autoaug.py new file mode 100644 index 0000000..3ecc697 --- /dev/null +++ b/resnest/torch/transforms/autoaug.py @@ -0,0 +1,197 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## Email: zhanghang0704@gmail.com +## Copyright (c) 2020 +## +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +# code adapted from: +# https://github.com/kakaobrain/fast-autoaugment +# https://github.com/rpmcruz/autoaugment +import math +import random + +import numpy as np +from collections import defaultdict +import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw + +RESAMPLE_MODE=PIL.Image.BICUBIC + +RANDOM_MIRROR = True + +def ShearX(img, v, resample=RESAMPLE_MODE): # [-0.3, 0.3] + assert -0.3 <= v <= 0.3 + if RANDOM_MIRROR and random.random() > 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0), + resample=resample) + +def ShearY(img, v, resample=RESAMPLE_MODE): # [-0.3, 0.3] + assert -0.3 <= v <= 0.3 + if RANDOM_MIRROR and random.random() > 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0), + resample=resample) + + +def TranslateX(img, v, resample=RESAMPLE_MODE): # [-150, 150] => percentage: [-0.45, 0.45] + assert -0.45 <= v <= 0.45 + if RANDOM_MIRROR and random.random() > 0.5: + v = -v + v = v * img.size[0] + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0), + resample=resample) + + +def TranslateY(img, v, resample=RESAMPLE_MODE): # [-150, 150] => percentage: [-0.45, 0.45] + assert -0.45 <= v <= 0.45 + if RANDOM_MIRROR and random.random() > 0.5: + v = -v + v = v * img.size[1] + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v), + resample=resample) + + +def TranslateXabs(img, v, resample=RESAMPLE_MODE): # [-150, 150] => percentage: [-0.45, 0.45] + assert 0 <= v + if random.random() > 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0), + resample=resample) + + +def TranslateYabs(img, v, resample=RESAMPLE_MODE): # [-150, 150] => percentage: [-0.45, 0.45] + assert 0 <= v + if random.random() > 0.5: + v = -v + return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v), + resample=resample) + + +def Rotate(img, v): # [-30, 30] + assert -30 <= v <= 30 + if RANDOM_MIRROR and random.random() > 0.5: + v = -v + return img.rotate(v) + + +def AutoContrast(img, _): + return PIL.ImageOps.autocontrast(img) + + +def Invert(img, _): + return PIL.ImageOps.invert(img) + + +def Equalize(img, _): + return PIL.ImageOps.equalize(img) + + +def Flip(img, _): # not from the paper + return PIL.ImageOps.mirror(img) + + +def Solarize(img, v): # [0, 256] + assert 0 <= v <= 256 + return PIL.ImageOps.solarize(img, v) + + +def SolarizeAdd(img, addition=0, threshold=128): + img_np = np.array(img).astype(np.int) + img_np = img_np + addition + img_np = np.clip(img_np, 0, 255) + img_np = img_np.astype(np.uint8) + img = PIL.Image.fromarray(img_np) + return PIL.ImageOps.solarize(img, threshold) + + +def Posterize(img, v): # [4, 8] + #assert 4 <= v <= 8 + v = int(v) + return PIL.ImageOps.posterize(img, v) + +def Contrast(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Contrast(img).enhance(v) + + +def Color(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Color(img).enhance(v) + + +def Brightness(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Brightness(img).enhance(v) + + +def Sharpness(img, v): # [0.1,1.9] + assert 0.1 <= v <= 1.9 + return PIL.ImageEnhance.Sharpness(img).enhance(v) + + +def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] + # assert 0 <= v <= 20 + if v < 0: + return img + w, h = img.size + x0 = np.random.uniform(w) + y0 = np.random.uniform(h) + + x0 = int(max(0, x0 - v / 2.)) + y0 = int(max(0, y0 - v / 2.)) + x1 = min(w, x0 + v) + y1 = min(h, y0 + v) + + xy = (x0, y0, x1, y1) + color = (125, 123, 114) + # color = (0, 0, 0) + img = img.copy() + PIL.ImageDraw.Draw(img).rectangle(xy, color) + return img + + +def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] + assert 0.0 <= v <= 0.2 + if v <= 0.: + return img + + v = v * img.size[0] + return CutoutAbs(img, v) + +def rand_augment_list(): # 16 oeprations and their ranges + l = [ + (AutoContrast, 0, 1), + (Equalize, 0, 1), + (Invert, 0, 1), + (Rotate, 0, 30), + (Posterize, 0, 4), + (Solarize, 0, 256), + (SolarizeAdd, 0, 110), + (Color, 0.1, 1.9), + (Contrast, 0.1, 1.9), + (Brightness, 0.1, 1.9), + (Sharpness, 0.1, 1.9), + (ShearX, 0., 0.3), + (ShearY, 0., 0.3), + (CutoutAbs, 0, 40), + (TranslateXabs, 0., 100), + (TranslateYabs, 0., 100), + ] + + return l + +class RandAugment(object): + def __init__(self, n, m): + self.n = n + self.m = m + self.augment_list = rand_augment_list() + + def __call__(self, img): + ops = random.choices(self.augment_list, k=self.n) + for op, minval, maxval in ops: + if random.random() > random.uniform(0.2, 0.8): + continue + val = (float(self.m) / 30) * float(maxval - minval) + minval + img = op(img, val) + return img diff --git a/resnest/torch/transforms/build.py b/resnest/torch/transforms/build.py new file mode 100644 index 0000000..95eaf33 --- /dev/null +++ b/resnest/torch/transforms/build.py @@ -0,0 +1,53 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## Email: zhanghang0704@gmail.com +## Copyright (c) 2020 +## +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +import torch +from torchvision.transforms import * +from .transforms import * +from fvcore.common.registry import Registry + +RESNEST_TRANSFORMS_REGISTRY = Registry('RESNEST_TRANSFORMS') + +def get_transform(dataset_name): + return RESNEST_TRANSFORMS_REGISTRY.get(dataset_name.lower()) + +@RESNEST_TRANSFORMS_REGISTRY.register() +def imagenet(base_size=None, crop_size=224, rand_aug=False): + normalize = Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + base_size = base_size if base_size is not None else int(1.0 * crop_size / 0.875) + train_transforms = [] + val_transforms = [] + if rand_aug: + from .autoaug import RandAugment + train_transforms.append(RandAugment(2, 12)) + + train_transforms.extend([ + ERandomCrop(crop_size), + RandomHorizontalFlip(), + ColorJitter(0.4, 0.4, 0.4), + ToTensor(), + Lighting(0.1, _imagenet_pca['eigval'], _imagenet_pca['eigvec']), + normalize, + ]) + val_transforms.extend([ + ECenterCrop(crop_size), + ToTensor(), + normalize, + ]) + transform_train = Compose(train_transforms) + transform_val = Compose(val_transforms) + return transform_train, transform_val + +_imagenet_pca = { + 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), + 'eigvec': torch.Tensor([ + [-0.5675, 0.7192, 0.4009], + [-0.5808, -0.0045, -0.8140], + [-0.5836, -0.6948, 0.4203], + ]) +} diff --git a/resnest/torch/transforms/transforms.py b/resnest/torch/transforms/transforms.py new file mode 100644 index 0000000..6ecd589 --- /dev/null +++ b/resnest/torch/transforms/transforms.py @@ -0,0 +1,122 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## Email: zhanghang0704@gmail.com +## Copyright (c) 2020 +## +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +import math +import random + +from PIL import Image +from torchvision.transforms import Resize, InterpolationMode + +__all__ = ['Lighting', 'ERandomCrop', 'ECenterCrop'] + +class Lighting(object): + """Lighting noise(AlexNet - style PCA - based noise)""" + + def __init__(self, alphastd, eigval, eigvec): + self.alphastd = alphastd + self.eigval = eigval + self.eigvec = eigvec + + def __call__(self, img): + if self.alphastd == 0: + return img + + alpha = img.new().resize_(3).normal_(0, self.alphastd) + rgb = self.eigvec.type_as(img).clone()\ + .mul(alpha.view(1, 3).expand(3, 3))\ + .mul(self.eigval.view(1, 3).expand(3, 3))\ + .sum(1).squeeze() + + return img.add(rgb.view(3, 1, 1).expand_as(img)) + + +#https://github.com/kakaobrain/fast-autoaugment/blob/master/FastAutoAugment/data.py +class ERandomCrop: + def __init__(self, imgsize, min_covered=0.1, aspect_ratio_range=(3./4, 4./3), + area_range=(0.1, 1.0), max_attempts=10): + assert 0.0 < min_covered + assert 0 < aspect_ratio_range[0] <= aspect_ratio_range[1] + assert 0 < area_range[0] <= area_range[1] + assert 1 <= max_attempts + + self.imgsize = imgsize + self.min_covered = min_covered + self.aspect_ratio_range = aspect_ratio_range + self.area_range = area_range + self.max_attempts = max_attempts + self._fallback = ECenterCrop(imgsize) + self.resize_method = Resize((imgsize, imgsize), + interpolation=InterpolationMode.BILINEAR) + + def __call__(self, img): + original_width, original_height = img.size + min_area = self.area_range[0] * (original_width * original_height) + max_area = self.area_range[1] * (original_width * original_height) + + for _ in range(self.max_attempts): + aspect_ratio = random.uniform(*self.aspect_ratio_range) + height = int(round(math.sqrt(min_area / aspect_ratio))) + max_height = int(round(math.sqrt(max_area / aspect_ratio))) + + if max_height * aspect_ratio > original_width: + max_height = (original_width + 0.5 - 1e-7) / aspect_ratio + max_height = int(max_height) + if max_height * aspect_ratio > original_width: + max_height -= 1 + + if max_height > original_height: + max_height = original_height + + if height >= max_height: + height = max_height + + height = int(round(random.uniform(height, max_height))) + width = int(round(height * aspect_ratio)) + area = width * height + + if area < min_area or area > max_area: + continue + if width > original_width or height > original_height: + continue + if area < self.min_covered * (original_width * original_height): + continue + if width == original_width and height == original_height: + return self._fallback(img) + + x = random.randint(0, original_width - width) + y = random.randint(0, original_height - height) + img = img.crop((x, y, x + width, y + height)) + return self.resize_method(img) + + return self._fallback(img) + + +class ECenterCrop: + """Crop the given PIL Image and resize it to desired size. + Args: + img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + output_size (sequence or int): (height, width) of the crop box. If int, + it is used for both directions + Returns: + PIL Image: Cropped image. + """ + def __init__(self, imgsize): + self.imgsize = imgsize + self.resize_method = Resize((imgsize, imgsize), + interpolation=InterpolationMode.BILINEAR) + + def __call__(self, img): + image_width, image_height = img.size + image_short = min(image_width, image_height) + + crop_size = float(self.imgsize) / (self.imgsize + 32) * image_short + + crop_height, crop_width = crop_size, crop_size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + img = img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height)) + return self.resize_method(img) diff --git a/resnest/torch/utils.py b/resnest/torch/utils.py new file mode 100644 index 0000000..9ea1f1a --- /dev/null +++ b/resnest/torch/utils.py @@ -0,0 +1,204 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## ECE Department, Rutgers University +## Email: zhang.hang@rutgers.edu +## Copyright (c) 2017 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import os +import math +import shutil +import functools +import threading +import numpy as np +import torch +from ..utils import mkdir + +__all__ = ['accuracy', 'AverageMeter', 'LR_Scheduler', + 'torch_dist_sum', 'MixUpWrapper', 'save_checkpoint'] + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + #self.val = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + #self.val = val + self.sum += val * n + self.count += n + + @property + def avg(self): + avg = 0 if self.count == 0 else self.sum / self.count + return avg + + +def torch_dist_sum(gpu, *args): + process_group = torch.distributed.group.WORLD + tensor_args = [] + pending_res = [] + for arg in args: + if isinstance(arg, torch.Tensor): + tensor_arg = arg.clone().reshape(-1).detach().cuda(gpu) + else: + tensor_arg = torch.tensor(arg).reshape(-1).cuda(gpu) + tensor_args.append(tensor_arg) + pending_res.append(torch.distributed.all_reduce(tensor_arg, group=process_group, async_op=True)) + for res in pending_res: + res.wait() + return tensor_args + +def get_rank(): + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + else: + rank = 0 + return rank + +def master_only(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if get_rank() == 0: + return func(*args, **kwargs) + else: + return None + return wrapper + +@master_only +def master_only_print(*args): + """master-only print""" + print(*args) + +class LR_Scheduler(object): + """Learning Rate Scheduler + + Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` + + Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` + + Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` + + Args: + args: :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), + :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, + :attr:`args.lr_step` + + iters_per_epoch: number of iterations per epoch + """ + def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, + lr_step=0, warmup_epochs=0, quiet=False, + logger=None): + self.mode = mode + self.quiet = quiet + self.logger = logger + if not quiet: + msg = 'Using {} LR scheduler with warm-up epochs of {}!'.format(self.mode, warmup_epochs) + if self.logger: + self.logger.info(msg) + else: + master_only_print() + if mode == 'step': + assert lr_step + self.base_lr = base_lr + self.lr_step = lr_step + self.iters_per_epoch = iters_per_epoch + self.epoch = -1 + self.warmup_iters = warmup_epochs * iters_per_epoch + self.total_iters = (num_epochs - warmup_epochs) * iters_per_epoch + + def __call__(self, optimizer, i, epoch, best_pred): + T = epoch * self.iters_per_epoch + i + # warm up lr schedule + if self.warmup_iters > 0 and T < self.warmup_iters: + lr = self.base_lr * 1.0 * T / self.warmup_iters + elif self.mode == 'cos': + T = T - self.warmup_iters + lr = 0.5 * self.base_lr * (1 + math.cos(1.0 * T / self.total_iters * math.pi)) + elif self.mode == 'poly': + T = T - self.warmup_iters + lr = self.base_lr * pow((1 - 1.0 * T / self.total_iters), 0.9) + elif self.mode == 'step': + lr = self.base_lr * (0.1 ** (epoch // self.lr_step)) + else: + raise NotImplementedError + if epoch > self.epoch and (epoch == 0 or best_pred > 0.0): + if not self.quiet: + msg = '\n=>Epoch %i, learning rate = %.4f, \ + previous best = %.4f' % (epoch, lr, best_pred) + if self.logger: + self.logger.info(msg) + else: + master_only_print() + self.epoch = epoch + assert lr >= 0 + self._adjust_learning_rate(optimizer, lr) + + def _adjust_learning_rate(self, optimizer, lr): + for i in range(len(optimizer.param_groups)): + optimizer.param_groups[i]['lr'] = lr + + +class MixUpWrapper(object): + def __init__(self, alpha, num_classes, dataloader, device): + self.alpha = alpha + self.dataloader = dataloader + self.num_classes = num_classes + self.device = device + + def mixup_loader(self, loader): + def mixup(alpha, num_classes, data, target): + with torch.no_grad(): + bs = data.size(0) + c = np.random.beta(alpha, alpha) + perm = torch.randperm(bs).cuda() + + md = c * data + (1-c) * data[perm, :] + mt = c * target + (1-c) * target[perm, :] + return md, mt + + for input, target in loader: + input, target = input.cuda(self.device), target.cuda(self.device) + target = torch.nn.functional.one_hot(target, self.num_classes) + i, t = mixup(self.alpha, self.num_classes, input, target) + yield i, t + + def __len__(self): + return len(self.dataloader) + + def __iter__(self): + return self.mixup_loader(self.dataloader) + +@master_only +def save_checkpoint(state, directory, is_best, filename='checkpoint.pth'): + """Saves checkpoint to disk""" + mkdir(directory) + filename = os.path.join(directory, filename) + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, os.path.join(directory, 'model_best.pth')) + diff --git a/scripts/gluon/train.py b/scripts/gluon/train.py index ccbf87a..97ede40 100644 --- a/scripts/gluon/train.py +++ b/scripts/gluon/train.py @@ -42,7 +42,7 @@ from resnest.gluon import get_model from resnest.utils import mkdir -from resnest.transforms import ERandomCrop, ECenterCrop +from resnest.gluon.transforms import ERandomCrop, ECenterCrop from torchvision.transforms import transforms as pth_transforms try: diff --git a/scripts/gluon/verify.py b/scripts/gluon/verify.py index 5e7f7e2..f79b59c 100644 --- a/scripts/gluon/verify.py +++ b/scripts/gluon/verify.py @@ -117,7 +117,7 @@ def test(network, ctx, val_data, batch_fn): resize = int(math.ceil(input_size/crop_ratio)) if input_size >= 320: - from resnest.transforms import ECenterCrop + from resnest.gluon.transforms import ECenterCrop from resnest.gluon.data_utils import ToPIL, ToNDArray transform_test = transforms.Compose([ ToPIL(), diff --git a/scripts/torch/train.py b/scripts/torch/train.py new file mode 100644 index 0000000..9896fe7 --- /dev/null +++ b/scripts/torch/train.py @@ -0,0 +1,291 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## Email: zhanghang0704@gmail.com +## Copyright (c) 2020 +## +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import os +import time +import logging +import argparse + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel + +from resnest.utils import mkdir +from resnest.torch.config import get_cfg +from resnest.torch.models.build import get_model +from resnest.torch.datasets import get_dataset +from resnest.torch.transforms import get_transform +from resnest.torch.loss import get_criterion +from resnest.torch.utils import (save_checkpoint, accuracy, + AverageMeter, LR_Scheduler, torch_dist_sum) + +logger = logging.getLogger('train') +logger.setLevel(logging.INFO) + +class Options(): + def __init__(self): + # data settings + parser = argparse.ArgumentParser(description='ResNeSt Training') + parser.add_argument('--config-file', type=str, default=None, + help='training configs') + parser.add_argument('--outdir', type=str, default='output', + help='output directory') + # checking point + parser.add_argument('--resume', type=str, default=None, + help='put the path to resuming file if needed') + # distributed + parser.add_argument('--world-size', default=1, type=int, + help='number of nodes for distributed training') + parser.add_argument('--rank', default=0, type=int, + help='node rank for distributed training') + parser.add_argument('--dist-url', default='tcp://localhost:23456', type=str, + help='url used to set up distributed training') + parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') + # evaluation option + parser.add_argument('--eval-only', action='store_true', default= False, + help='evaluating') + parser.add_argument('--export', type=str, default=None, + help='put the path to resuming file if needed') + self.parser = parser + + def parse(self): + args = self.parser.parse_args() + return args + +def main(): + args = Options().parse() + ngpus_per_node = torch.cuda.device_count() + args.world_size = ngpus_per_node * args.world_size + + # load config + cfg = get_cfg() + cfg.merge_from_file(args.config_file) + + cfg.OPTIMIZER.LR = cfg.OPTIMIZER.LR * args.world_size + mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args, cfg)) + +# global variable +best_pred = 0.0 +acclist_train = [] +acclist_val = [] + +def main_worker(gpu, ngpus_per_node, args, cfg): + args.gpu = gpu + args.rank = args.rank * ngpus_per_node + gpu + logger.info(f'rank: {args.rank} / {args.world_size}') + dist.init_process_group(backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank) + torch.cuda.set_device(args.gpu) + if args.gpu == 0: + mkdir(args.outdir) + fh = logging.FileHandler(os.path.join(args.outdir, 'log.txt')) + fh.setLevel(logging.INFO) + logger.addHandler(fh) + logger.info(args) + + # init the global + global best_pred, acclist_train, acclist_val + + # seed + torch.manual_seed(cfg.SEED) + torch.cuda.manual_seed(cfg.SEED) + + # init dataloader + transform_train, transform_val = get_transform(cfg.DATA.DATASET)( + cfg.DATA.BASE_SIZE, cfg.DATA.CROP_SIZE, cfg.DATA.RAND_AUG) + trainset = get_dataset(cfg.DATA.DATASET)(root=cfg.DATA.ROOT, + transform=transform_train, + train=True, + download=True) + valset = get_dataset(cfg.DATA.DATASET)(root=cfg.DATA.ROOT, + transform=transform_val, + train=False, + download=True) + + train_sampler = torch.utils.data.distributed.DistributedSampler(trainset) + train_loader = torch.utils.data.DataLoader( + trainset, batch_size=cfg.TRAINING.BATCH_SIZE, shuffle=False, + num_workers=cfg.TRAINING.WORKERS, pin_memory=True, + sampler=train_sampler) + + val_sampler = torch.utils.data.distributed.DistributedSampler(valset, shuffle=False) + val_loader = torch.utils.data.DataLoader( + valset, batch_size=cfg.TRAINING.TEST_BATCH_SIZE, shuffle=False, + num_workers=cfg.TRAINING.WORKERS, pin_memory=True, + sampler=val_sampler) + + # init the model + model_kwargs = {} + if cfg.MODEL.FINAL_DROP > 0.0: + model_kwargs['final_drop'] = cfg.MODEL.FINAL_DROP + + if cfg.TRAINING.LAST_GAMMA: + model_kwargs['last_gamma'] = True + + model = get_model(cfg.MODEL.NAME)(**model_kwargs) + + if args.gpu == 0: + logger.info(model) + + criterion, train_loader = get_criterion(cfg, train_loader, args.gpu) + + model.cuda(args.gpu) + criterion.cuda(args.gpu) + model = DistributedDataParallel(model, device_ids=[args.gpu]) + + # criterion and optimizer + if cfg.OPTIMIZER.DISABLE_BN_WD: + parameters = model.named_parameters() + param_dict = {} + for k, v in parameters: + param_dict[k] = v + bn_params = [v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)] + rest_params = [v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)] + if args.gpu == 0: + logger.info(" Weight decay NOT applied to BN parameters ") + logger.info(f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}') + optimizer = torch.optim.SGD([{'params': bn_params, 'weight_decay': 0 }, + {'params': rest_params, 'weight_decay': cfg.OPTIMIZER.WEIGHT_DECAY}], + lr=cfg.OPTIMIZER.LR, + momentum=cfg.OPTIMIZER.MOMENTUM, + weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY) + else: + optimizer = torch.optim.SGD(model.parameters(), + lr=cfg.OPTIMIZER.LR, + momentum=cfg.OPTIMIZER.MOMENTUM, + weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY) + # check point + if args.resume is not None: + if os.path.isfile(args.resume): + if args.gpu == 0: + logger.info(f"=> loading checkpoint '{args.resume}'") + checkpoint = torch.load(args.resume) + cfg.TRAINING.START_EPOCHS = checkpoint['epoch'] + 1 if cfg.TRAINING.START_EPOCHS == 0 \ + else cfg.TRAINING.START_EPOCHS + best_pred = checkpoint['best_pred'] + acclist_train = checkpoint['acclist_train'] + acclist_val = checkpoint['acclist_val'] + model.module.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + if args.gpu == 0: + logger.info(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") + else: + raise RuntimeError (f"=> no resume checkpoint found at '{args.resume}'") + + scheduler = LR_Scheduler(cfg.OPTIMIZER.LR_SCHEDULER, + base_lr=cfg.OPTIMIZER.LR, + num_epochs=cfg.TRAINING.EPOCHS, + iters_per_epoch=len(train_loader), + warmup_epochs=cfg.OPTIMIZER.WARMUP_EPOCHS) + def train(epoch): + train_sampler.set_epoch(epoch) + model.train() + losses = AverageMeter() + top1 = AverageMeter() + global best_pred, acclist_train + for batch_idx, (data, target) in enumerate(train_loader): + scheduler(optimizer, batch_idx, epoch, best_pred) + if not cfg.DATA.MIXUP: + data, target = data.cuda(args.gpu), target.cuda(args.gpu) + optimizer.zero_grad() + output = model(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + + if not cfg.DATA.MIXUP: + acc1 = accuracy(output, target, topk=(1,)) + top1.update(acc1[0], data.size(0)) + + losses.update(loss.item(), data.size(0)) + if batch_idx % 100 == 0 and args.gpu == 0: + if cfg.DATA.MIXUP: + logger.info('Batch: %d| Loss: %.3f'%(batch_idx, losses.avg)) + else: + logger.info('Batch: %d| Loss: %.3f | Top1: %.3f'%(batch_idx, losses.avg, top1.avg)) + + acclist_train += [top1.avg] + + def validate(epoch): + model.eval() + top1 = AverageMeter() + top5 = AverageMeter() + global best_pred, acclist_train, acclist_val + is_best = False + for batch_idx, (data, target) in enumerate(val_loader): + data, target = data.cuda(args.gpu), target.cuda(args.gpu) + with torch.no_grad(): + output = model(data) + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + top1.update(acc1[0], data.size(0)) + top5.update(acc5[0], data.size(0)) + + # sum all + sum1, cnt1, sum5, cnt5 = torch_dist_sum(args.gpu, top1.sum, top1.count, top5.sum, top5.count) + + if args.gpu == 0: + top1_acc = sum(sum1) / sum(cnt1) + top5_acc = sum(sum5) / sum(cnt5) + logger.info('Validation: Top1: %.3f | Top5: %.3f'%(top1_acc, top5_acc)) + if args.eval_only: + return + + # save checkpoint + acclist_val += [top1_acc] + if top1_acc > best_pred: + best_pred = top1_acc + is_best = True + save_checkpoint({ + 'epoch': epoch, + 'state_dict': model.module.state_dict(), + 'optimizer': optimizer.state_dict(), + 'best_pred': best_pred, + 'acclist_train':acclist_train, + 'acclist_val':acclist_val, + }, + directory=args.outdir, + is_best=False, + filename=f'checkpoint_{epoch}.pth') + + if args.export: + if args.gpu == 0: + torch.save(model.module.state_dict(), args.export + '.pth') + return + + if args.eval_only: + validate(cfg.TRAINING.START_EPOCHS) + return + + for epoch in range(cfg.TRAINING.START_EPOCHS, cfg.TRAINING.EPOCHS): + tic = time.time() + train(epoch) + if epoch % 10 == 0 or epoch == cfg.TRAINING.EPOCHS - 1: + validate(epoch) + elapsed = time.time() - tic + if args.gpu == 0: + logger.info(f'Epoch: {epoch}, Time cost: {elapsed}') + + if args.gpu == 0: + save_checkpoint({ + 'epoch': cfg.TRAINING.EPOCHS - 1, + 'state_dict': model.module.state_dict(), + 'optimizer': optimizer.state_dict(), + 'best_pred': best_pred, + 'acclist_train':acclist_train, + 'acclist_val':acclist_val, + }, + directory=args.outdir, + is_best=False, + filename='model_final.pth') + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 91b2e93..eada4be 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ def create_version_file(): 'Pillow', 'scipy', 'requests', + 'fvcore', ] if __name__ == '__main__': diff --git a/tests/test_radix_major.py b/tests/test_radix_major.py index 00bfbbb..8e1cee3 100644 --- a/tests/test_radix_major.py +++ b/tests/test_radix_major.py @@ -5,7 +5,7 @@ from torch.nn import Conv2d, Module, Linear, BatchNorm2d, ReLU from torch.nn.modules.utils import _pair -from resnest.torch.splat import SplAtConv2d +from resnest.torch.models.splat import SplAtConv2d class RadixMajorNaiveImp(Module): """Split-Attention Conv2d