-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathconfig_pretrain.yaml
79 lines (74 loc) · 1.64 KB
/
config_pretrain.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# ---- Training / optimizer settings ----
gpu: 'cuda:0'             # device string; presumably handed to torch — confirm
# NOTE(review): PyYAML (YAML 1.1) resolves `2e-4` and `1e-7` as *strings*
# (its float pattern needs a decimal point); YAML 1.2 loaders read floats.
# Kept verbatim in case the training code converts them (float()/eval()).
lr: 2e-4
min_lr: 1e-7
weight_decay: 0.0
epochs: 5
warmup_epochs: 0.7        # fractional epoch count — semantics in trainer, confirm
patience_epochs: 0.3      # fractional epoch count — semantics in trainer, confirm
# Quoted to make explicit that this is the *string* 'None' (exactly what the
# bare `None` already parsed to); YAML null would be spelled `null`.
load_model: 'None'
log_every_n_steps: 50
# Select the GNN model. Supported models:
# - SchNet: K. T. Schütt et al., https://doi.org/10.1063/1.5019779
# - SE(3)-Transformer: F. B. Fuchs et al., https://arxiv.org/abs/2006.10503
# - EGNN: V. G. Satorras et al., https://arxiv.org/abs/2102.09844
# - TorchMD-Net: P. Thölke et al., https://arxiv.org/abs/2202.02541
# The default settings for each model are listed below. Keep exactly one model
# block uncommented (EGNN is currently active); the blocks share keys.
model:
    # Only one of the model blocks below may be uncommented at a time: they
    # share keys such as `name` and `cutoff`, and duplicate keys are silently
    # resolved last-wins by most YAML parsers.

    # ---- SchNet (inactive) ----
    # name: 'SchNet'
    # num_atoms: 28
    # bond_feat_dim: 1
    # num_targets: 1
    # hidden_channels: 256
    # num_filters: 256
    # num_interactions: 5
    # num_gaussians: 32
    # cutoff: 5.0
    # max_num_neighbors: 32
    # readout: 'add'

    # ---- SE(3)-Transformer (inactive) ----
    # name: 'SE3Transformer'
    # num_layers: 3
    # atom_feature_size: 28
    # num_channels: 8
    # num_degrees: 4
    # edge_dim: 4
    # div: 4
    # pooling: 'avg'
    # n_heads: 2
    # cutoff: 5.0
    # max_num_neighbors: 32

    # ---- EGNN (active) ----
    name: 'EGNN'
    hidden_channels: 256
    in_edge_nf: 0
    n_layers: 5
    residual: true
    attention: true
    normalize: true
    tanh: false
    cutoff: 5.0
    max_atom_type: 28
    max_chirality_type: 5
    max_num_neighbors: 32

    # ---- TorchMD-Net (inactive) ----
    # name: 'TorchMD-Net'
    # hidden_channels: 256
    # num_layers: 6
    # num_rbf: 32
    # rbf_type: 'expnorm'
    # trainable_rbf: true
    # activation: 'silu'
    # attn_activation: 'silu'
    # neighbor_embedding: true
    # num_heads: 8
    # distance_influence: 'both'
    # cutoff_lower: 0.0
    # cutoff_upper: 5.0
    # max_atom_type: 28
    # max_chirality_type: 5
    # max_num_neighbors: 32
dataset:
    # ---- Data loading ----
    batch_size: 256
    num_workers: 8
    valid_size: 0.05      # presumably the validation split fraction — confirm
    # Presumably toggles for the ANI-1 / ANI-1x data sources — confirm.
    ani1: true
    ani1x: true
    std: 0.2              # NOTE(review): likely a noise std; verify in code
    seed: 777