
Commit

update dpmd test
y1xiaoc committed Apr 13, 2021
1 parent ddef847 commit 887c92f
Showing 2 changed files with 47 additions and 27 deletions.
67 changes: 47 additions & 20 deletions tests/dpmdargs.py
@@ -316,7 +316,7 @@ def model_args ():
return ca


def learning_rate_args():
def learning_rate_exp():
doc_start_lr = 'The learning rate at the start of the training.'
doc_stop_lr = 'The desired learning rate at the end of the training.'
doc_decay_steps = 'The learning rate decays every this number of training steps.'
@@ -326,9 +326,24 @@ def learning_rate_args():
Argument("stop_lr", float, optional = True, default = 1e-8, doc = doc_stop_lr),
Argument("decay_steps", int, optional = True, default = 5000, doc = doc_decay_steps)
]
return args


def learning_rate_variant_type_args():
doc_lr = 'The type of the learning rate. Currently only `exp`, the exponentially decaying learning rate, is supported.'

return Variant("type",
[Argument("exp", dict, learning_rate_exp())],
optional = True,
default_tag = 'exp',
doc = doc_lr)


doc_lr = "The learning rate options"
return Argument("learning_rate", dict, args, [], doc = doc_lr)
def learning_rate_args():
doc_lr = "The definitio of learning rate"
return Argument("learning_rate", dict, [],
[learning_rate_variant_type_args()],
doc = doc_lr)
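
A minimal sketch of how the variant-based spec above would be exercised (assuming dargs' `Argument.check_value` interface): the "type" key selects the `exp` sub-arguments, and `exp` is applied by default when the key is omitted.

lr_spec = learning_rate_args()
# explicit type tag; start_lr/stop_lr/decay_steps fall back to their defaults if omitted
lr_spec.check_value({"type": "exp", "start_lr": 0.001, "decay_steps": 5000})
# the tag itself is optional because the variant declares default_tag = 'exp'
lr_spec.check_value({"start_lr": 0.001})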


def start_pref(item):
@@ -378,15 +393,16 @@ def loss_args():
return ca

def training_args():
link_sys = make_link("systems", "training/systems")
doc_systems = 'The data systems. This key can be provided with a list that specifies the systems, or be provided with a string by which the prefix of all systems is given and the list of the systems is automatically generated.'
doc_set_prefix = 'The prefix of the sets in the systems.'
doc_set_prefix = f'The prefix of the sets in the {link_sys}.'
doc_stop_batch = 'Number of training batch. Each training uses one batch of data.'
doc_batch_size = 'This key can be \n\n\
- list: the length of which is the same as the `systems`. The batch size of each system is given by the elements of the list.\n\n\
- int: all `systems` uses the same batch size.\n\n\
- string "auto": automatically determines the batch size os that the batch_size times the number of atoms in the system is no less than 32.\n\n\
- string "auto:N": automatically determines the batch size os that the batch_size times the number of atoms in the system is no less than N.'
doc_seed = 'The random seed for training.'
doc_batch_size = f'This key can be \n\n\
- list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\
- int: all {link_sys} use the same batch size.\n\n\
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.'
doc_seed = 'The random seed for getting frames from the training data set.'
doc_disp_file = 'The file for printing learning curve.'
doc_disp_freq = 'The frequency of printing learning curve.'
doc_numb_test = 'Number of frames used for the test during training.'
@@ -396,12 +412,21 @@ def training_args():
doc_time_training = 'Timing during training.'
doc_profiling = 'Profiling during training.'
doc_profiling_file = 'Output file for profiling.'
doc_train_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
- "prob_uniform" : the probabilities of all the systems are equal, namely 1.0/self.get_nsystems()\n\n\
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\
- "prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;..." : the list of systems is divided into blocks. A block is specified by `stt_idx:end_idx:weight`, where `stt_idx` is the starting index of the system, `end_idx` is the ending (not inclusive) index of the system, the probabilities of the systems in this block sum up to `weight`, and the relative probabilities within this block are proportional to the number of batches in the system.'
doc_train_sys_probs = "A list of float, should be of the same length as `train_systems`, specifying the probability of each system."
doc_tensorboard = 'Enable tensorboard'
doc_tensorboard_log_dir = 'The log directory of tensorboard outputs'

args = [
Argument("systems", [list,str], optional = False, doc = doc_systems),
Argument("systems", [list,str], optional = False, doc = doc_systems, alias = ["trn_systems"]),
Argument("set_prefix", str, optional = True, default = 'set', doc = doc_set_prefix),
Argument("stop_batch", int, optional = False, doc = doc_stop_batch),
Argument("batch_size", [list,int,str], optional = True, default = 'auto', doc = doc_batch_size),
Argument("auto_prob", str, optional = True, default = "prob_sys_size", doc = doc_train_auto_prob_style, alias = ["trn_auto_prob", "auto_prob_style"]),
Argument("sys_probs", list, optional = True, default = None, doc = doc_train_sys_probs, alias = ["trn_sys_probs"]),
Argument("batch_size", [list,int,str], optional = True, default = 'auto', doc = doc_batch_size, alias = ["trn_batch_size"]),
Argument("numb_steps", int, optional = False, doc = doc_stop_batch, alias = ["stop_batch"]),
Argument("seed", [int,None], optional = True, doc = doc_seed),
Argument("disp_file", str, optional = True, default = 'lcueve.out', doc = doc_disp_file),
Argument("disp_freq", int, optional = True, default = 1000, doc = doc_disp_freq),
@@ -411,7 +436,9 @@ def training_args():
Argument("disp_training", bool, optional = True, default = True, doc = doc_disp_training),
Argument("time_training", bool, optional = True, default = True, doc = doc_time_training),
Argument("profiling", bool, optional = True, default = False, doc = doc_profiling),
Argument("profiling_file", str, optional = True, default = 'timeline.json', doc = doc_profiling_file)
Argument("profiling_file", str, optional = True, default = 'timeline.json', doc = doc_profiling_file),
Argument("tensorboard", bool, optional = True, default = False, doc = doc_tensorboard),
Argument("tensorboard_log_dir", str, optional = True, default = 'log', doc = doc_tensorboard_log_dir),
]

doc_training = 'The training options'
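
A minimal sketch of a "training" section that the updated argument list above would accept, exercising the new keys under their canonical names (values are illustrative only, and `check_value` is assumed from dargs' Argument API):

example_training = {
    "systems": ["system_0", "system_1"],
    "numb_steps": 1000000,   # "stop_batch" is kept as an alias of this key
    "batch_size": "auto:64",
    "auto_prob": "prob_sys_size",
    "seed": 1,
    "tensorboard": True,
    "tensorboard_log_dir": "log",
}
training_args().check_value(example_training)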
@@ -493,14 +520,14 @@ def normalize(data):
},
"learning_rate" :{
"_type": "exp",
"type": "exp",
"decay_steps": 5000,
"start_lr": 0.001,
"stop_lr": 3.51e-8,
"_comment": "that's all"
},
},
"loss" :{
"loss" :{
"start_pref_e": 0.02,
"limit_pref_e": 1,
"start_pref_f": 1000,
@@ -526,11 +553,11 @@ def normalize(data):
"numb_test": 10,
"save_freq": 1000,
"save_ckpt": "model.ckpt",
"_load_ckpt": "model.ckpt",
"load_ckpt": "model.ckpt",
"disp_training":true,
"time_training":true,
"_tensorboard": false,
"_tensorboard_log_dir":"log",
"tensorboard": false,
"tensorboard_log_dir":"log",
"profiling": false,
"profiling_file":"timeline.json",
"_comment": "that's all"
7 changes: 0 additions & 7 deletions tests/test_checker.py
@@ -258,13 +258,6 @@ def test_sub_variants(self):
Argument("type2", dict)])
])

def test_dpmd(self):
import json
from dpmdargs import check, example_json_str
data = json.loads(example_json_str)
check(data)
# print("\n\n"+docstr)


if __name__ == "__main__":
unittest.main()
