config_finetune.yaml
CV_flag: False # whether to use cross-validation
add_vocab_flag: True # whether to add supplementary vocab
LLRD_flag: False # whether to use layer-wise learning rate decay (LLRD)
aug_flag: False # whether to apply data augmentation
aug_special_flag: False # whether to augment copolymer SMILES
model_indicator: 'pretrain' # whether to start from the pretrained model; 'pretrain' initializes from the pretrained checkpoint
aug_indicator: # number of augmentations per SMILES; leave blank (None) for no limit
vocab_sup_file: 'data/vocab/vocab_sup_PE_I.csv' # supplementary vocab file path
train_file: 'data/train_PE_I.csv' # train file path
test_file: 'data/test_PE_I.csv' # test file path if cross-validation is not used
model_path: 'ckpt/pretrain.pt' # pretrain model path
save_path: 'ckpt/PE_I_train.pt' # checkpoint path
best_model_path: 'ckpt/PE_I_best_model.pt' # best model save path
k: 5 # number of folds for k-fold cross-validation
blocksize: 411 # max length of sequences after tokenization
batch_size: 32 # batch size
num_epochs: 20 # total number of epochs
warmup_ratio: 0.05 # fraction of training steps used for learning-rate warmup
drop_rate: 0.1 # dropout rate
lr_rate: 0.00005 # initial learning rate when LLRD is used; otherwise the learning rate of the pretrained model
lr_rate_reg: 0.0001 # learning rate of the regressor head if LLRD is not used
weight_decay: 0.01 # weight decay
hidden_dropout_prob: 0.1 # hidden layer dropout
attention_probs_dropout_prob: 0.1 # attention dropout
tolerance: 5 # number of epochs with no improvement before early stopping
num_workers: 8 # number of workers when loading data
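
A minimal sketch of how a training script might consume this file, assuming it is parsed with PyYAML into a plain dict; the load_config helper and the key accesses below are illustrative, not taken from the repository:

import yaml

def load_config(path="config_finetune.yaml"):
    # Hypothetical helper: read the YAML file into a Python dict.
    with open(path, "r") as f:
        return yaml.safe_load(f)

config = load_config()

# Flags such as CV_flag / LLRD_flag / aug_flag switch code paths;
# paths and hyperparameters are read directly by key.
batch_size = config["batch_size"]      # 32
lr = config["lr_rate"]                 # 5e-5
use_cv = config["CV_flag"]             # False -> use the train/test split instead of k-fold
aug_limit = config["aug_indicator"]    # None when the field is left blank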