From 518aa174c4af071fa5811d310e89fd548ef4cb4b Mon Sep 17 00:00:00 2001 From: ffiirree Date: Sun, 28 Nov 2021 16:12:52 +0800 Subject: [PATCH] create models from timm --- infer.py => benchmark.py | 3 +-- cvm/utils/factory.py | 28 +++++++++++++++++++++++--- cvm/utils/utils.py | 43 ++++++++++++++++++++++++++++++++-------- flops.py | 8 +++----- info.py | 3 +-- profile.py | 3 +-- train.py | 7 ++----- validate.py | 5 +---- 8 files changed, 69 insertions(+), 31 deletions(-) rename infer.py => benchmark.py (93%) diff --git a/infer.py b/benchmark.py similarity index 93% rename from infer.py rename to benchmark.py index 5fa9fc8..b829bc6 100644 --- a/infer.py +++ b/benchmark.py @@ -32,14 +32,13 @@ def infer(self): if __name__ == '__main__': parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser.add_argument('--model', '-m', type=str) - parser.add_argument('--torch', action='store_true') parser.add_argument('--batch-size', type=int, default=16) parser.add_argument('--amp', action='store_true') args = parser.parse_args() print(args) - model = create_model(args.model, torch=args.torch) + model = create_model(args.model) input = torch.randn(args.batch_size, 3, 224, 224) diff --git a/cvm/utils/factory.py b/cvm/utils/factory.py index 914f1de..103ab7e 100644 --- a/cvm/utils/factory.py +++ b/cvm/utils/factory.py @@ -19,6 +19,12 @@ from nvidia.dali.pipeline import pipeline_def from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy +try: + import timm + has_timm = True +except ImportError: + has_timm = False + __all__ = [ 'create_model', 'create_optimizer', 'create_scheduler', 'create_transforms', 'create_dataset', 'create_loader' @@ -126,19 +132,35 @@ def create_dali_pipeline( def create_model( name: str, pretrained: bool = False, - torch: bool = False, cuda: bool = True, sync_bn: bool = False, distributed: bool = False, local_rank: int = 0, **kwargs ): - if torch: + if name.startswith('torch/'): + name = name.replace('torch/', '') model = torchvision.models.__dict__[name](pretrained=pretrained) + elif name.startswith('timm/'): + assert has_timm, 'Please install timm first.' + name = name.replace('timm/', '') + model = timm.create_model( + name, + pretrained=pretrained, + num_classes=kwargs.get('num_classes', 1000), + drop_rate=kwargs.get('dropout_rate', 0.0), + drop_path_rate=kwargs.get('drop_path_rate', None), + drop_block_rate=kwargs.get('drop_block', None), + bn_momentum=kwargs.get('bn_momentum', None), + bn_eps=kwargs.get('bn_eps', None), + scriptable=kwargs.get('scriptable', False), + checkpoint_path=kwargs.get('initial_checkpoint', None), + ) else: if 'bn_eps' in kwargs and kwargs['bn_eps'] and 'bn_momentum' in kwargs and kwargs['bn_momentum']: with cvm.models.core.blocks.normalizer(partial(nn.BatchNorm2d, eps=kwargs['bn_eps'], momentum=kwargs['bn_momentum'])): - model = cvm.models.__dict__[name](pretrained=pretrained, **kwargs) + model = cvm.models.__dict__[name]( + pretrained=pretrained, **kwargs) model = cvm.models.__dict__[name](pretrained=pretrained, **kwargs) if cuda: diff --git a/cvm/utils/utils.py b/cvm/utils/utils.py index 17befb4..d1b8de7 100644 --- a/cvm/utils/utils.py +++ b/cvm/utils/utils.py @@ -10,6 +10,12 @@ from cvm import models import torch.distributed as dist +try: + import timm + has_timm = True +except ImportError: + has_timm = False + __all__ = [ 'Benchmark', 'env_info', 'manual_seed', 'named_layers', 'accuracy', 'AverageMeter', @@ -167,15 +173,36 @@ def update(self, val, n=1): self.avg = self.sum / self.count -def list_models(torch: bool = False): - if torch: - return sorted(name for name in torchvision.models.__dict__ - if name.islower() and not name.startswith("__") - and callable(torchvision.models.__dict__[name])) +def _filter_models(name_list, prefix='', sort=False): + models = [prefix + name for name in name_list + if name.islower() and not name.startswith("__") + and callable(name_list[name])] + return models if not sort else sorted(models) + + +def list_models(lib: str = 'all'): + assert lib in ['all', 'cvm', 'torch', 'timm'], f'Unknown library {lib}.' - return sorted(name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name])) + if lib == 'all': + cvm_models = _filter_models(torchvision.models.__dict__, sort=True) + torch_models = _filter_models(models.__dict__, 'torch/', True) + + timm_models = [ + 'timm/' + name for name in timm.list_models() + ] if has_timm else [] + return cvm_models + torch_models + timm_models + + elif lib == 'torch': + return _filter_models( + torchvision.models.__dict__, + prefix='torch/', + sort=True + ) + elif lib == 'timm': + assert has_timm, 'Please install timm first.' + return ['timm/' + name for name in timm.list_models()] + else: + return _filter_models(models.__dict__, sort=True) def list_datasets(): diff --git a/flops.py b/flops.py index c5ee9f6..2e9b2d6 100755 --- a/flops.py +++ b/flops.py @@ -15,9 +15,8 @@ def print_model(model, table: bool = False): if __name__ == '__main__': parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser.add_argument('--model', '-m', type=str) - parser.add_argument('--torch', action='store_true') parser.add_argument('--table', action='store_true') - parser.add_argument('--models', action='store_true') + parser.add_argument('--list-models', type=str, default=None) parser.add_argument('--num-classes', type=int, default=1000) parser.add_argument('--image-size', type=int, default=224) @@ -27,14 +26,13 @@ def print_model(model, table: bool = False): thumbnail = True if args.image_size < 100 else False - if args.models: - print(json.dumps(list_models(args.torch), indent=4)) + if args.list_models: + print(json.dumps(list_models(args.list_models), indent=4)) else: print_model( create_model( args.model, thumbnail=thumbnail, - torch=args.torch, num_classes=args.num_classes, cuda=False, ), diff --git a/info.py b/info.py index 02f8461..63eb4cf 100644 --- a/info.py +++ b/info.py @@ -5,11 +5,10 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser.add_argument('--model', '-m', type=str) - parser.add_argument('--torch', action='store_true') args = parser.parse_args() - model = create_model(args.model, torch=args.torch, cuda=False) + model = create_model(args.model, cuda=False) summary( model, diff --git a/profile.py b/profile.py index d0ec519..c858400 100644 --- a/profile.py +++ b/profile.py @@ -9,10 +9,9 @@ parser.add_argument('--model', type=str, default='micronet_b1_0') parser.add_argument('--batch-size', type=int, default=64, metavar='N') parser.add_argument('--amp', action='store_true') - parser.add_argument('--torch', action='store_true') args = parser.parse_args() - model = create_model(args.model, torch=args.torch) + model = create_model(args.model) model.eval() images = torch.randn([args.batch_size, 3, 224, 224]).cuda() diff --git a/train.py b/train.py index 80b29fe..31d942e 100644 --- a/train.py +++ b/train.py @@ -28,10 +28,8 @@ def parse_args(): parser.add_argument('--val-crop-size', type=int, default=224) # model - parser.add_argument('--model', type=str, default='muxnet_v2', choices=list_models() + list_models(True), - help='type of model to use. (default: muxnet_v2)') - parser.add_argument('--torch', action='store_true', - help='use torchvision models. (default: false)') + parser.add_argument('--model', type=str, default='resnet18_v1', choices=list_models(), + help='type of model to use. (default: resnet18_v1)') parser.add_argument('--pretrained', action='store_true', help='use pre-trained model. (default: false)') parser.add_argument('--model-path', type=str, default=None) @@ -223,7 +221,6 @@ def validate(val_loader, model, criterion): model = create_model( args.model, - torch=args.torch, num_classes=args.num_classes, dropout_rate=args.dropout_rate, drop_path_rate=args.drop_path_rate, diff --git a/validate.py b/validate.py index 15ef51e..c8908e3 100644 --- a/validate.py +++ b/validate.py @@ -11,10 +11,8 @@ def parse_args(): help='path to the ImageNet dataset.') parser.add_argument('--data-dir', type=str, default='/datasets/ILSVRC2012', help='path to the ImageNet dataset.') - parser.add_argument('--model', type=str, default='mobilenet_v1_x1_0', choices=list_models() + list_models(True), + parser.add_argument('--model', type=str, default='mobilenet_v1_x1_0', choices=list_models(), help='type of model to use. (default: mobilenet_v1_x1_0)') - parser.add_argument('--torch', action='store_true', - help='use torchvision models. (default: false)') parser.add_argument('--pretrained', action='store_true', help='use pre-trained model. (default: false)') parser.add_argument('--model-path', type=str, default=None) @@ -65,7 +63,6 @@ def validate(val_loader, model, args): model = create_model( args.model, pretrained=args.pretrained, - torch=args.torch, pth=args.model_path )