diff --git a/jaxreaxff/driver.py b/jaxreaxff/driver.py index e2f320c..81cbe8e 100755 --- a/jaxreaxff/driver.py +++ b/jaxreaxff/driver.py @@ -36,6 +36,7 @@ parse_and_save_force_field) import math from functools import partial +from jaxreaxff.helper import build_float_range_checker def main(): # create parser for command-line arguments @@ -133,6 +134,13 @@ def main(): help='R|Max number of clusters that can be used\n' + 'High number of clusters lowers the memory cost\n' + 'However, it increases compilation time,especially for cpus') + parser.add_argument('--perc_noise_when_stuck', metavar='percentage', + type=build_float_range_checker(0.0, 0.1), + default=0.04, + help='R|Percentage of the noise that will be added to the parameters\n' + + 'when the optimizer is stuck.\n' + + 'param_noise_i = (param_min_i, param_max_i) * perc_noise_when_stuck\n' + + 'Allowed range: [0.0, 0.1]') parser.add_argument('--seed', metavar='seed', type=int, default=0, @@ -149,9 +157,9 @@ def main(): print("To use the GPU version, jaxlib with CUDA support needs to installed!") # advanced options - advanced_opts = {"perc_err_change_thr":0.01, # if change in error is less than this threshold, add noise - "perc_noise_when_stuck":0.04, # noise percantage (wrt param range) to add when stuck - "perc_width_rest_search":0.15, # width of the restricted parameter search after iteration > rest_search_start + advanced_opts = {"perc_err_change_thr":0.01, # if change in error is less than this threshold, add noise + "perc_noise_when_stuck":args.perc_noise_when_stuck, # noise percantage (wrt param range) to add when stuck + "perc_width_rest_search":0.15, # width of the restricted parameter search after iteration > rest_search_start } onp.random.seed(args.seed) diff --git a/jaxreaxff/helper.py b/jaxreaxff/helper.py index c597189..b477375 100644 --- a/jaxreaxff/helper.py +++ b/jaxreaxff/helper.py @@ -20,11 +20,28 @@ from jaxreaxff.inter_list_counter import pool_handler_for_inter_list_count from jax_md import dataclasses from jax_md.reaxff.reaxff_forcefield import ForceField +import argparse + # Since we shouldnt access the private API (jaxlib), create a dummy jax array # and get the type information from the array. #from jaxlib.xla_extension import ArrayImpl as JaxArrayType JaxArrayType = type(jnp.zeros(1)) +def build_float_range_checker(min_v, max_v): + ''' + Returns a function that can be used to validate fiven FP value + withing the allowed range ([min_v, max_v]) + ''' + def range_checker(arg): + try: + val = float(arg) + except ValueError: + raise argparse.ArgumentTypeError("Value must be a floating point number") + if val < min_v or val > max_v: + raise argparse.ArgumentTypeError("Value must be in range [" + str(min_v) + ", " + str(max_v)+"]") + return val + return range_checker + def get_params(force_field, params_list): ''' Get the selected parameters from the force field