Add bulk_runner script and updates to benchmark.py and validate.py for better error handling in bulk runs (used for benchmark and validation result runs). Improved batch size decay stepping on retry...
rwightman committed Jul 19, 2022
1 parent 4547920 commit 0dbd935
Showing 5 changed files with 259 additions and 25 deletions.
33 changes: 16 additions & 17 deletions benchmark.py
@@ -21,7 +21,7 @@
from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry

has_apex = False
try:
@@ -506,34 +506,31 @@ def run(self):
        return results


def decay_batch_exp(batch_size, factor=0.5, divisor=16):
    out_batch_size = batch_size * factor
    if out_batch_size > divisor:
        out_batch_size = (out_batch_size + 1) // divisor * divisor
    else:
        out_batch_size = batch_size - 1
    return max(0, int(out_batch_size))


def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
def _try_run(
        model_name,
        bench_fn,
        bench_kwargs,
        initial_batch_size,
        no_batch_size_retry=False
):
    batch_size = initial_batch_size
    results = dict()
    error_str = 'Unknown'
    while batch_size >= 1:
        torch.cuda.empty_cache()
    while batch_size:
        try:
            torch.cuda.empty_cache()
            bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
            results = bench.run()
            return results
        except RuntimeError as e:
            error_str = str(e)
            if 'channels_last' in error_str:
                _logger.error(f'{model_name} not supported in channels_last, skipping.')
                break
            _logger.error(f'"{error_str}" while running benchmark.')
            if not check_batch_size_retry(error_str):
                _logger.error(f'Unrecoverable error encountered while benchmarking {model_name}, skipping.')
                break
            if no_batch_size_retry:
                break
            batch_size = decay_batch_exp(batch_size)
            batch_size = decay_batch_step(batch_size)
            _logger.warning(f'Reducing batch size to {batch_size} for retry.')
    results['error'] = error_str
    return results
@@ -586,6 +583,8 @@ def benchmark(args):
        if prefix and 'error' not in run_results:
            run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
        model_results.update(run_results)
        if 'error' in run_results:
            break
    if 'error' not in model_results:
        param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
        model_results.setdefault('param_count', param_count)
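For readers skimming the hunks above: the retry pattern that benchmark.py (and validate.py, further down) moves to in this commit can be condensed as follows. This is an illustrative sketch, not code from the commit; `run_with_retry` is a made-up helper name, and it assumes a timm checkout that includes this commit so the two new utils are importable.

```python
import torch
from timm.utils import decay_batch_step, check_batch_size_retry


def run_with_retry(run_fn, initial_batch_size):
    """Illustrative condensation of the updated _try_run loops (not part of the commit)."""
    batch_size = initial_batch_size
    last_error = 'Unknown'
    while batch_size:  # decay_batch_step() returns 0 once batch_size reaches 1, ending the loop
        try:
            torch.cuda.empty_cache()
            return run_fn(batch_size)
        except RuntimeError as e:
            last_error = str(e)
            if not check_batch_size_retry(last_error):
                break  # e.g. channels_last rank mismatch or illegal memory access: no point retrying
            batch_size = decay_batch_step(batch_size)
    return {'error': last_error}
```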
184 changes: 184 additions & 0 deletions bulk_runner.py
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
""" Bulk Model Script Runner
Run validation or benchmark script in separate process for each model
Benchmark all 'vit*' models:
python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512
Validate all models:
python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import sys
import csv
import json
import subprocess
import time
from typing import Callable, List, Tuple, Union


from timm.models import is_model, list_models


parser = argparse.ArgumentParser(description='Per-model process launcher')

# model and results args
parser.add_argument(
    '--model-list', metavar='NAME', default='',
    help='txt file based list of model names to benchmark')
parser.add_argument(
    '--results-file', default='', type=str, metavar='FILENAME',
    help='Output csv file for validation results (summary)')
parser.add_argument(
    '--sort-key', default='', type=str, metavar='COL',
    help='Specify sort key for results csv')
parser.add_argument(
    "--pretrained", action='store_true',
    help="only run models with pretrained weights")

parser.add_argument(
    "--delay",
    type=float,
    default=0,
    help="Interval, in seconds, to delay between model invocations.",
)
parser.add_argument(
    "--start_method", type=str, default="spawn", choices=["spawn", "fork", "forkserver"],
    help="Multiprocessing start method to use when creating workers.",
)
parser.add_argument(
    "--no_python",
    help="Skip prepending the script with 'python' - just execute it directly. Useful "
    "when the script is not a Python script.",
)
parser.add_argument(
    "-m",
    "--module",
    help="Change each process to interpret the launch script as a Python module, executing "
    "with the same behavior as 'python -m'.",
)

# positional
parser.add_argument(
    "script", type=str,
    help="Full path to the program/script to be launched for each model config.",
)
parser.add_argument("script_args", nargs=argparse.REMAINDER)


def cmd_from_args(args) -> Tuple[Union[Callable, str], List[str]]:
    # If ``args`` not passed, defaults to ``sys.argv[:1]``
    with_python = not args.no_python
    cmd: Union[Callable, str]
    cmd_args = []
    if with_python:
        cmd = os.getenv("PYTHON_EXEC", sys.executable)
        cmd_args.append("-u")
        if args.module:
            cmd_args.append("-m")
        cmd_args.append(args.script)
    else:
        if args.module:
            raise ValueError(
                "Don't use both the '--no_python' flag"
                " and the '--module' flag at the same time."
            )
        cmd = args.script
    cmd_args.extend(args.script_args)

    return cmd, cmd_args


def main():
    args = parser.parse_args()
    cmd, cmd_args = cmd_from_args(args)

    model_cfgs = []
    model_names = []
    if args.model_list == 'all':
        # NOTE should make this config, for validation / benchmark runs the focus is 1k models,
        # so we filter out 21/22k and some other unusable heads. This will change in the future...
        exclude_model_filters = ['*in21k', '*in22k', '*dino', '*_22k']
        model_names = list_models(
            pretrained=args.pretrained,  # only include models w/ pretrained checkpoints if set
            exclude_filters=exclude_model_filters
        )
        model_cfgs = [(n, None) for n in model_names]
    elif not is_model(args.model_list):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model_list)
        model_cfgs = [(n, None) for n in model_names]

    if not model_cfgs and os.path.exists(args.model_list):
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = [(n, None) for n in model_names]

    if len(model_cfgs):
        results_file = args.results_file or './results.csv'
        results = []
        errors = []
        print('Running script on these models: {}'.format(', '.join(model_names)))
        if not args.sort_key:
            if 'benchmark' in args.script:
                if any(['train' in a for a in args.script_args]):
                    sort_key = 'train_samples_per_sec'
                else:
                    sort_key = 'infer_samples_per_sec'
            else:
                sort_key = 'top1'
        else:
            sort_key = args.sort_key
        print(f'Script: {args.script}, Args: {args.script_args}, Sort key: {sort_key}')

        try:
            for m, _ in model_cfgs:
                if not m:
                    continue
                args_str = (cmd, *[str(e) for e in cmd_args], '--model', m)
                try:
                    o = subprocess.check_output(args=args_str).decode('utf-8').split('--result')[-1]
                    r = json.loads(o)
                    results.append(r)
                except Exception as e:
                    # FIXME batch_size retry loop is currently done in either validation.py or benchmark.py
                    # for further robustness (but more overhead), we may want to manage that by looping here...
                    errors.append(dict(model=m, error=str(e)))
                if args.delay:
                    time.sleep(args.delay)
        except KeyboardInterrupt as e:
            pass

        errors.extend(list(filter(lambda x: 'error' in x, results)))
        if errors:
            print(f'{len(errors)} models had errors during run.')
            for e in errors:
                print(f"\t {e['model']} ({e.get('error', 'Unknown')})")
        results = list(filter(lambda x: 'error' not in x, results))

        no_sortkey = list(filter(lambda x: sort_key not in x, results))
        if no_sortkey:
            print(f'{len(no_sortkey)} results missing sort key, skipping sort.')
        else:
            results = sorted(results, key=lambda x: x[sort_key], reverse=True)

        if len(results):
            print(f'{len(results)} models run successfully. Saving results to {results_file}.')
            write_results(results_file, results)


def write_results(results_file, results):
    with open(results_file, mode='w') as cf:
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for r in results:
            dw.writerow(r)
        cf.flush()


if __name__ == '__main__':
    main()
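The contract between the runner and each launched script is the `--result` delimiter: the child prints its metrics as JSON after a `--result` line (see the `print` added to validate.py further down), and the runner takes everything after the last occurrence of that marker in the captured stdout and parses it. A minimal sketch of both sides follows; the metric values are illustrative only and not taken from any real run.

```python
import json

# Child side (what validate.py / benchmark.py do at the end of a run): emit the marker, then JSON.
results = {'model': 'resnet50', 'top1': 80.1}  # illustrative values only
print(f'--result\n{json.dumps(results, indent=4)}')

# Runner side (what bulk_runner.py does with the subprocess output): split on the marker and parse.
captured = 'arbitrary log lines...\n' + f'--result\n{json.dumps(results, indent=4)}'
parsed = json.loads(captured.split('--result')[-1])
assert parsed['top1'] == results['top1']
```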
1 change: 1 addition & 0 deletions timm/utils/__init__.py
@@ -2,6 +2,7 @@
from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .decay_batch import decay_batch_step, check_batch_size_retry
from .distributed import distribute_bn, reduce_tensor
from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
43 changes: 43 additions & 0 deletions timm/utils/decay_batch.py
@@ -0,0 +1,43 @@
""" Batch size decay and retry helpers.
Copyright 2022 Ross Wightman
"""
import math


def decay_batch_step(batch_size, num_intra_steps=2, no_odd=False):
    """ power of two batch-size decay with intra steps
    Decay by stepping between powers of 2:
    * determine power-of-2 floor of current batch size (base batch size)
    * divide above value by num_intra_steps to determine step size
    * floor batch_size to nearest multiple of step_size (from base batch size)
    Examples:
     num_steps == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1
     num_steps (no_odd=True) == 4 --> 64, 56, 48, 40, 32, 28, 24, 20, 16, 14, 12, 10, 8, 6, 4, 2
     num_steps == 2 --> 64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1
     num_steps == 1 --> 64, 32, 16, 8, 4, 2, 1
    """
    if batch_size <= 1:
        # return 0 for stopping value so easy to use in loop
        return 0
    base_batch_size = int(2 ** (math.log(batch_size - 1) // math.log(2)))
    step_size = max(base_batch_size // num_intra_steps, 1)
    batch_size = base_batch_size + ((batch_size - base_batch_size - 1) // step_size) * step_size
    if no_odd and batch_size % 2:
        batch_size -= 1
    return batch_size


def check_batch_size_retry(error_str):
    """ check failure error string for conditions where batch decay retry should not be attempted
    """
    error_str = error_str.lower()
    if 'required rank' in error_str:
        # Errors involving phrase 'required rank' typically happen when a conv is used that's
        # not compatible with channels_last memory format.
        return False
    if 'illegal' in error_str:
        # 'Illegal memory access' errors in CUDA typically leave process in unusable state
        return False
    return True
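A quick way to sanity-check the stepping behaviour (a sketch, assuming a timm checkout that includes this commit) is to walk a starting batch size down to zero; with the default `num_intra_steps=2` this reproduces the `num_steps == 2` sequence from the docstring above.

```python
from timm.utils import decay_batch_step

bs, seq = 64, []
while bs:  # decay_batch_step() returns 0 once the batch size reaches 1
    seq.append(bs)
    bs = decay_batch_step(bs, num_intra_steps=2)
print(seq)  # [64, 48, 32, 24, 16, 12, 8, 6, 4, 3, 2, 1]
```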
23 changes: 15 additions & 8 deletions validate.py
@@ -22,7 +22,8 @@

from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
    decay_batch_step, check_batch_size_retry

has_apex = False
try:
@@ -122,6 +123,8 @@
                    help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                    help='Valid label indices txt file for validation of partial label space')
parser.add_argument('--retry', default=False, action='store_true',
                    help='Enable batch size decay & retry for single model validation')


def validate(args):
@@ -303,18 +306,19 @@ def _try_run(args, initial_batch_size):
    batch_size = initial_batch_size
    results = OrderedDict()
    error_str = 'Unknown'
    while batch_size >= 1:
        args.batch_size = batch_size
        torch.cuda.empty_cache()
    while batch_size:
        args.batch_size = batch_size * args.num_gpu  # multiply by num-gpu for DataParallel case
        try:
            torch.cuda.empty_cache()
            results = validate(args)
            return results
        except RuntimeError as e:
            error_str = str(e)
            if 'channels_last' in error_str:
            _logger.error(f'"{error_str}" while running validation.')
            if not check_batch_size_retry(error_str):
                break
            _logger.warning(f'"{error_str}" while running validation. Reducing batch size to {batch_size} for retry.')
            batch_size = batch_size // 2
            batch_size = decay_batch_step(batch_size)
            _logger.warning(f'Reducing batch size to {batch_size} for retry.')
    results['error'] = error_str
    _logger.error(f'{args.model} failed to validate ({error_str}).')
    return results
@@ -368,7 +372,10 @@ def main():
        if len(results):
            write_results(results_file, results)
    else:
        results = validate(args)
        if args.retry:
            results = _try_run(args, args.batch_size)
        else:
            results = validate(args)
    # output results in JSON to stdout w/ delimiter for runner script
    print(f'--result\n{json.dumps(results, indent=4)}')

