Skip to content

Commit

Permalink
create models from timm
Browse files Browse the repository at this point in the history
  • Loading branch information
ffiirree committed Nov 28, 2021
1 parent 899e7b9 commit 518aa17
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 31 deletions.
3 changes: 1 addition & 2 deletions infer.py → benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ def infer(self):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--model', '-m', type=str)
parser.add_argument('--torch', action='store_true')
parser.add_argument('--batch-size', type=int, default=16)
parser.add_argument('--amp', action='store_true')

args = parser.parse_args()
print(args)

model = create_model(args.model, torch=args.torch)
model = create_model(args.model)

input = torch.randn(args.batch_size, 3, 224, 224)

Expand Down
28 changes: 25 additions & 3 deletions cvm/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
from nvidia.dali.pipeline import pipeline_def
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy

try:
import timm
has_timm = True
except ImportError:
has_timm = False

__all__ = [
'create_model', 'create_optimizer', 'create_scheduler',
'create_transforms', 'create_dataset', 'create_loader'
Expand Down Expand Up @@ -126,19 +132,35 @@ def create_dali_pipeline(
def create_model(
name: str,
pretrained: bool = False,
torch: bool = False,
cuda: bool = True,
sync_bn: bool = False,
distributed: bool = False,
local_rank: int = 0,
**kwargs
):
if torch:
if name.startswith('torch/'):
name = name.replace('torch/', '')
model = torchvision.models.__dict__[name](pretrained=pretrained)
elif name.startswith('timm/'):
assert has_timm, 'Please install timm first.'
name = name.replace('timm/', '')
model = timm.create_model(
name,
pretrained=pretrained,
num_classes=kwargs.get('num_classes', 1000),
drop_rate=kwargs.get('dropout_rate', 0.0),
drop_path_rate=kwargs.get('drop_path_rate', None),
drop_block_rate=kwargs.get('drop_block', None),
bn_momentum=kwargs.get('bn_momentum', None),
bn_eps=kwargs.get('bn_eps', None),
scriptable=kwargs.get('scriptable', False),
checkpoint_path=kwargs.get('initial_checkpoint', None),
)
else:
if 'bn_eps' in kwargs and kwargs['bn_eps'] and 'bn_momentum' in kwargs and kwargs['bn_momentum']:
with cvm.models.core.blocks.normalizer(partial(nn.BatchNorm2d, eps=kwargs['bn_eps'], momentum=kwargs['bn_momentum'])):
model = cvm.models.__dict__[name](pretrained=pretrained, **kwargs)
model = cvm.models.__dict__[name](
pretrained=pretrained, **kwargs)
model = cvm.models.__dict__[name](pretrained=pretrained, **kwargs)

if cuda:
Expand Down
43 changes: 35 additions & 8 deletions cvm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
from cvm import models
import torch.distributed as dist

try:
import timm
has_timm = True
except ImportError:
has_timm = False

__all__ = [
'Benchmark', 'env_info', 'manual_seed',
'named_layers', 'accuracy', 'AverageMeter',
Expand Down Expand Up @@ -167,15 +173,36 @@ def update(self, val, n=1):
self.avg = self.sum / self.count


def list_models(torch: bool = False):
if torch:
return sorted(name for name in torchvision.models.__dict__
if name.islower() and not name.startswith("__")
and callable(torchvision.models.__dict__[name]))
def _filter_models(name_list, prefix='', sort=False):
models = [prefix + name for name in name_list
if name.islower() and not name.startswith("__")
and callable(name_list[name])]
return models if not sort else sorted(models)


def list_models(lib: str = 'all'):
assert lib in ['all', 'cvm', 'torch', 'timm'], f'Unknown library {lib}.'

return sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
if lib == 'all':
cvm_models = _filter_models(torchvision.models.__dict__, sort=True)
torch_models = _filter_models(models.__dict__, 'torch/', True)

timm_models = [
'timm/' + name for name in timm.list_models()
] if has_timm else []
return cvm_models + torch_models + timm_models

elif lib == 'torch':
return _filter_models(
torchvision.models.__dict__,
prefix='torch/',
sort=True
)
elif lib == 'timm':
assert has_timm, 'Please install timm first.'
return ['timm/' + name for name in timm.list_models()]
else:
return _filter_models(models.__dict__, sort=True)


def list_datasets():
Expand Down
8 changes: 3 additions & 5 deletions flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ def print_model(model, table: bool = False):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--model', '-m', type=str)
parser.add_argument('--torch', action='store_true')
parser.add_argument('--table', action='store_true')
parser.add_argument('--models', action='store_true')
parser.add_argument('--list-models', type=str, default=None)
parser.add_argument('--num-classes', type=int, default=1000)
parser.add_argument('--image-size', type=int, default=224)

Expand All @@ -27,14 +26,13 @@ def print_model(model, table: bool = False):

thumbnail = True if args.image_size < 100 else False

if args.models:
print(json.dumps(list_models(args.torch), indent=4))
if args.list_models:
print(json.dumps(list_models(args.list_models), indent=4))
else:
print_model(
create_model(
args.model,
thumbnail=thumbnail,
torch=args.torch,
num_classes=args.num_classes,
cuda=False,
),
Expand Down
3 changes: 1 addition & 2 deletions info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--model', '-m', type=str)
parser.add_argument('--torch', action='store_true')

args = parser.parse_args()

model = create_model(args.model, torch=args.torch, cuda=False)
model = create_model(args.model, cuda=False)

summary(
model,
Expand Down
3 changes: 1 addition & 2 deletions profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
parser.add_argument('--model', type=str, default='micronet_b1_0')
parser.add_argument('--batch-size', type=int, default=64, metavar='N')
parser.add_argument('--amp', action='store_true')
parser.add_argument('--torch', action='store_true')
args = parser.parse_args()

model = create_model(args.model, torch=args.torch)
model = create_model(args.model)
model.eval()

images = torch.randn([args.batch_size, 3, 224, 224]).cuda()
Expand Down
7 changes: 2 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ def parse_args():
parser.add_argument('--val-crop-size', type=int, default=224)

# model
parser.add_argument('--model', type=str, default='muxnet_v2', choices=list_models() + list_models(True),
help='type of model to use. (default: muxnet_v2)')
parser.add_argument('--torch', action='store_true',
help='use torchvision models. (default: false)')
parser.add_argument('--model', type=str, default='resnet18_v1', choices=list_models(),
help='type of model to use. (default: resnet18_v1)')
parser.add_argument('--pretrained', action='store_true',
help='use pre-trained model. (default: false)')
parser.add_argument('--model-path', type=str, default=None)
Expand Down Expand Up @@ -223,7 +221,6 @@ def validate(val_loader, model, criterion):

model = create_model(
args.model,
torch=args.torch,
num_classes=args.num_classes,
dropout_rate=args.dropout_rate,
drop_path_rate=args.drop_path_rate,
Expand Down
5 changes: 1 addition & 4 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@ def parse_args():
help='path to the ImageNet dataset.')
parser.add_argument('--data-dir', type=str, default='/datasets/ILSVRC2012',
help='path to the ImageNet dataset.')
parser.add_argument('--model', type=str, default='mobilenet_v1_x1_0', choices=list_models() + list_models(True),
parser.add_argument('--model', type=str, default='mobilenet_v1_x1_0', choices=list_models(),
help='type of model to use. (default: mobilenet_v1_x1_0)')
parser.add_argument('--torch', action='store_true',
help='use torchvision models. (default: false)')
parser.add_argument('--pretrained', action='store_true',
help='use pre-trained model. (default: false)')
parser.add_argument('--model-path', type=str, default=None)
Expand Down Expand Up @@ -65,7 +63,6 @@ def validate(val_loader, model, args):
model = create_model(
args.model,
pretrained=args.pretrained,
torch=args.torch,
pth=args.model_path
)

Expand Down

0 comments on commit 518aa17

Please sign in to comment.