-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
60 lines (46 loc) · 1.9 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import functools
import os
import sys
import typing

import yaml
def latest_n_checkpoints(folder, *, prefix='checkpoint', all_but=False, last_n=1):
    """Select checkpoint entries in *folder*, ordered by step number.

    An entry of *folder* counts as a checkpoint when its name starts with
    *prefix* and it is not a regular file. Each name is expected to look like
    ``<prefix>-<step>[-...]``; the integer after the first ``-`` is the step
    used for ordering. Names whose step cannot be parsed (for example
    ``checkpoint_best``) are skipped.

    Args:
        folder: Directory to scan.
        prefix: Name prefix identifying checkpoint entries.
        all_but: If False, return the newest *last_n* checkpoints; if True,
            return every checkpoint except the newest *last_n*.
        last_n: How many of the newest checkpoints to select (or exclude).
            Negative values are treated as 0.

    Returns:
        A list of entry names (not full paths), ordered oldest to newest.
        Empty if nothing matches.
    """
    candidates = []
    for name in os.listdir(folder):
        if not name.startswith(prefix):
            continue
        # Resolve against `folder`: a bare `name` would be looked up relative
        # to the current working directory instead.
        if os.path.isfile(os.path.join(folder, name)):
            continue
        try:
            step = int(name.split("-")[1])
        except (IndexError, ValueError):
            # No numeric step to order by (e.g. "checkpoint_best").
            continue
        candidates.append((step, name))

    dirs = [name for _, name in sorted(candidates)]

    # Clamp instead of slicing with -last_n directly: dirs[-0:] / dirs[:-0]
    # would invert the meaning when last_n == 0.
    keep = min(max(last_n, 0), len(dirs))
    cut = len(dirs) - keep
    return dirs[:cut] if all_but else dirs[cut:]
def _remove_defaults(args: dict, raw_args: list):
modified_args = { }
for k, v in args.items():
if ('--' + k) not in raw_args:
modified_args[k] = v
return modified_args
def yaml_interface(script_file_path):
    """Build a decorator that lets a script read its arguments from a YAML file.

    The decorated function takes no arguments and returns a pair
    ``(parser, validation_fn)``: an ``argparse.ArgumentParser`` for the
    script's options and a callable that receives the final
    ``argparse.Namespace`` and returns the value handed back to the caller.

    The resulting ``get_args(argv=None)`` function:

    1. adds a ``--config`` option to the parser and parses *argv* (or
       ``sys.argv[1:]`` when *argv* is None);
    2. if ``--config`` was given, loads that YAML file and overlays its
       ``common`` section, then the section keyed by the script's base file
       name (e.g. ``train.py``), onto the parsed values. Entries whose option
       was typed explicitly on the command line are filtered out by
       ``_remove_defaults``, so they are not overwritten;
    3. returns ``validation_fn(namespace)``. ``config`` itself is not part of
       the namespace.

    Args:
        script_file_path: Path of the calling script (e.g. ``__file__``). Only
            its base name is used, as the YAML section key.

    Returns:
        A decorator mapping ``create_parser_fn`` to ``get_args``.
    """
    def _decorator(create_parser_fn: typing.Callable):
        @functools.wraps(create_parser_fn)
        def _get_args_fn(argv=None):
            script_file_name = os.path.basename(script_file_path)
            ret: typing.Tuple[argparse.ArgumentParser, typing.Callable] = create_parser_fn()
            parser, validation_fn = ret
            parser.add_argument(
                '--config',
                type=str,
                required=False,
                help='A yaml config file as a replacement for command line arguments'
            )
            args = parser.parse_args(argv)
            # Use the same tokens argparse saw to decide which options were explicit.
            raw_args = sys.argv[1:] if argv is None else list(argv)

            # --config is consumed here and is not forwarded to validation_fn.
            config_file = args.config
            del args.config
            combined_args = vars(args)

            if config_file:
                with open(config_file, 'r', encoding='utf-8') as f:
                    # NOTE(security): unsafe_load can construct arbitrary Python
                    # objects from YAML tags; only point --config at trusted
                    # files. Prefer yaml.safe_load if plain YAML types suffice.
                    yaml_dict: dict = yaml.unsafe_load(f) or {}
                # An empty file, or a section written as "common:" with no body,
                # loads as None; treat those as empty. Copy so the loaded
                # document is not mutated.
                config_from_file = dict(yaml_dict.get('common') or {})
                config_from_file.update(yaml_dict.get(script_file_name) or {})
                combined_args.update(_remove_defaults(config_from_file, raw_args))

            return validation_fn(argparse.Namespace(**combined_args))
        return _get_args_fn
    return _decorator