Skip to content

Commit

Permalink
Change to remake the reweighting info by default.
Browse files Browse the repository at this point in the history
Can be disabled (i.e., using existing reweighting info) with `--no-remake-weights`.
  • Loading branch information
hqucms committed Jul 2, 2022
1 parent fe8bf87 commit 9254354
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
install_requires.append(line)

setup(name="weaver-core",
version='0.3.5',
version='0.3.6',
description="A streamlined deep-learning framework for high energy physics",
long_description_content_type="text/markdown",
author="H. Qu, C. Li",
Expand Down
3 changes: 3 additions & 0 deletions weaver/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
help='load the whole dataset (and perform the preprocessing) only once and keep it in memory for the entire run')
parser.add_argument('--train-val-split', type=float, default=0.8,
help='training/validation split fraction')
parser.add_argument('--no-remake-weights', action='store_true', default=False,
help='do not remake weights for sampling (reweighting), use existing ones in the previous auto-generated data config YAML file')
parser.add_argument('--demo', action='store_true', default=False,
help='quickly test the setup by running over only a small number of events')
parser.add_argument('--lr-finder', type=str, default=None,
Expand Down Expand Up @@ -221,6 +223,7 @@ def train_load(args):
raise RuntimeError('Must set --steps-per-epoch when using --in-memory!')

train_data = SimpleIterDataset(train_file_dict, args.data_config, for_training=True,
remake_weights=not args.no_remake_weights,
load_range_and_fraction=(train_range, args.data_fraction),
file_fraction=args.file_fraction,
fetch_by_files=args.fetch_by_files,
Expand Down
15 changes: 9 additions & 6 deletions weaver/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,15 @@ def __init__(self, file_dict, data_config_file, for_training=True, load_range_an
self._sampler_options.update(training=False, shuffle=False, reweight=False)

# discover auto-generated reweight file
data_config_md5 = _md5(data_config_file)
data_config_autogen_file = data_config_file.replace('.yaml', '.%s.auto.yaml' % data_config_md5)
if os.path.exists(data_config_autogen_file):
data_config_file = data_config_autogen_file
_logger.info('Found file %s w/ auto-generated preprocessing information, will use that instead!' %
data_config_file)
if '.auto.yaml' in data_config_file:
data_config_autogen_file = data_config_file
else:
data_config_md5 = _md5(data_config_file)
data_config_autogen_file = data_config_file.replace('.yaml', '.%s.auto.yaml' % data_config_md5)
if os.path.exists(data_config_autogen_file):
data_config_file = data_config_autogen_file
_logger.info('Found file %s w/ auto-generated preprocessing information, will use that instead!' %
data_config_file)

# load data config (w/ observers now -- so they will be included in the auto-generated yaml)
self._data_config = DataConfig.load(data_config_file)
Expand Down

0 comments on commit 9254354

Please sign in to comment.