[WIP] Multimodal #71

Draft · wants to merge 132 commits into base: dev

Commits (132)
4c07734
multimodal
Jaesun0912 Aug 14, 2024
b4e6af8
Merge remote-tracking branch 'origin/main' into multimodel
YutackPark Aug 15, 2024
8cba28e
pre-commit done
Jaesun0912 Aug 16, 2024
1a9ff1e
chore: sync main, __version__
YutackPark Aug 19, 2024
cbf1c0c
fix: shift bug in model build
Jaesun0912 Aug 19, 2024
4cf7749
Create CHANGELOG.md
Jaesun0912 Oct 22, 2024
bebb601
fix: naive merge fix, not debugged, functionalities broken, now
YutackPark Oct 26, 2024
df95326
refactor+WIP: modal patch for linear
YutackPark Oct 27, 2024
5346e7e
Add comment to convert_model_modality.py
Jaesun0912 Oct 28, 2024
85ea475
add: dict reader
YutackPark Oct 30, 2024
d41c1f8
WIP add: multimodal dataset
YutackPark Oct 30, 2024
e755476
chore
YutackPark Oct 30, 2024
d640464
WIP refactor: multimodal, modality using onehot, modal map in sequent…
YutackPark Oct 30, 2024
dce6617
refactor: for dict reader
YutackPark Oct 30, 2024
4bd1b7d
chore
YutackPark Oct 31, 2024
6cc917b
fix: for dict converts string, device is needed to send it
YutackPark Oct 31, 2024
5c3486a
refactor: single pt specialized routines + dataweight
YutackPark Oct 31, 2024
75747c7
fix,add: able to read single pt, with updating data weight
YutackPark Oct 31, 2024
fd9d5af
refactor: error recorder, loss definitions, record error directly usi…
YutackPark Nov 1, 2024
22170f2
refactor: delete_unlabeled -> ignore_unlabeled
YutackPark Nov 2, 2024
fd5f1d6
refactor: use_modality, get optional
YutackPark Nov 2, 2024
5ec9b89
fix,refactor: from mapper
YutackPark Nov 2, 2024
40bb640
fix: linear
YutackPark Nov 2, 2024
96286f0
fix: support modality model continue, some logging, refactor graph da…
YutackPark Nov 2, 2024
282ce6e
fix,refactor: inference
YutackPark Nov 2, 2024
5b40082
fix: modal dataset
YutackPark Nov 2, 2024
50c65b3
fix
YutackPark Nov 2, 2024
cb46358
add: modal dataset train
YutackPark Nov 2, 2024
5940e3e
chore
YutackPark Nov 2, 2024
08f4056
fix: modal + inference
YutackPark Nov 2, 2024
960c389
refactor: no batch data for scale
YutackPark Nov 2, 2024
da55095
chore
YutackPark Nov 2, 2024
c4011f7
chore, change: remove dead code + print model config warning only whe…
YutackPark Nov 2, 2024
56a7796
fix,refactor: raise error for unknown chem + modality
YutackPark Nov 2, 2024
65148ae
add: use_weight for atoms dataset
YutackPark Nov 2, 2024
a2ac6ce
Merge branch 'mm_merge' into multimodal
YutackPark Nov 2, 2024
3833931
Merge branch 'dev' into multimodal
YutackPark Nov 2, 2024
d80b80d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
a11118c
version up + dev suffix
YutackPark Nov 5, 2024
4287ca6
bugfix: replace pt, .pt
YutackPark Nov 6, 2024
467c834
refactor: better error message for graph dataset
YutackPark Nov 6, 2024
26e9d02
refactor: drop info option, fix
YutackPark Nov 6, 2024
f119832
chore
YutackPark Nov 6, 2024
1289333
lint,fix
YutackPark Nov 6, 2024
cbec720
chore
YutackPark Nov 6, 2024
54004b0
refactor: more robust dataset naming
YutackPark Nov 6, 2024
fa7d077
fix: save real first epoch cp
YutackPark Nov 23, 2024
2c12601
refactor: util model_from_checkpont
YutackPark Nov 23, 2024
79defb8
merge dev
YutackPark Nov 23, 2024
05af5ac
Merge branch 'multimodal' into cu_equi
YutackPark Nov 23, 2024
05e64c8
refactor: sort convolution instructions, and fix flipped w3j coeff of…
YutackPark Nov 23, 2024
db5932a
refactor: _broadcast > broadcast
YutackPark Nov 24, 2024
faba81a
chore: trained -> deployed
YutackPark Nov 24, 2024
597800c
refactor: deploy takes checkpoint path
YutackPark Nov 24, 2024
0f3452c
docs: changelog
YutackPark Nov 24, 2024
aa07c9b
add: cuEquivariance support
YutackPark Nov 25, 2024
e08b425
Merge branch 'main' into cu_equi
YutackPark Nov 25, 2024
924c2e1
sync main
YutackPark Nov 25, 2024
333190c
mistake
YutackPark Nov 26, 2024
3f7017b
fix: bugfix, sort conv only if given version is below 0.12.0
YutackPark Nov 26, 2024
0d6bbbe
chore: trivials
YutackPark Nov 26, 2024
4e6e56d
refactor: separete preprocess
YutackPark Nov 27, 2024
a0b15ca
bugfix,add: fix parity=True for cueq (thanks to MACE), add cueq tests…
YutackPark Nov 27, 2024
c03cf0c
chore
YutackPark Nov 27, 2024
6260ccb
add,fix: Fully connected tp + params -> kwargs
YutackPark Nov 27, 2024
fbd17e0
fix tests
YutackPark Nov 27, 2024
0caac3f
solve stash conflict
YutackPark Nov 27, 2024
0cc0514
fall back to e3nn only when cuda is available
YutackPark Nov 27, 2024
e3e9345
chore: cue -> cueq
YutackPark Nov 27, 2024
2c4976e
fix: merge err
YutackPark Nov 27, 2024
5c61f9c
refactor: lmax edge
YutackPark Nov 27, 2024
8d8b160
fix: multi gpu training for edge cases
YutackPark Nov 27, 2024
029edba
Merge branch 'cu_equi' of github.com:MDIL-SNU/SevenNet into cu_equi
YutackPark Nov 27, 2024
2d2a6a2
e3nn > 0.5.0, due to changed CG coeff convention
YutackPark Nov 30, 2024
22e43aa
restore version
YutackPark Nov 30, 2024
facdb98
fix: old ver
YutackPark Nov 30, 2024
d14c0e2
remove old test cp
YutackPark Nov 30, 2024
0e54266
refactor: add overload for model build to quite pyright
YutackPark Nov 30, 2024
ae89660
refactor,add: checkpoint wrapper
YutackPark Nov 30, 2024
3aa15bc
add,fix: multimodal model with lammps
YutackPark Nov 30, 2024
79e2cf0
chore
YutackPark Nov 30, 2024
cc2969b
add: save time and hash key for checkpoint
YutackPark Nov 30, 2024
a4c72c8
chore
YutackPark Dec 1, 2024
4da8a20
refactor: some changes in _const order
YutackPark Dec 1, 2024
16fd041
add: sevenn_cp tool
YutackPark Dec 1, 2024
6b6c267
bugfix: scale from mapper was not working, but saved my like it was a…
YutackPark Dec 1, 2024
a86d4a1
bugfix: multimodal training dumping atom_type key after changing oneh…
YutackPark Dec 1, 2024
ff7c4bc
refactor: more robust irreps_tp_out
YutackPark Dec 3, 2024
6f4419d
Merge branch 'multimodal' of github.com:MDIL-SNU/SevenNet into multim…
YutackPark Dec 7, 2024
e943970
merge main
YutackPark Dec 7, 2024
8c1e1a2
add: modality append
YutackPark Dec 8, 2024
97372e0
refactor
YutackPark Dec 8, 2024
3776258
add: modal related tests
YutackPark Dec 8, 2024
5ed3e09
bugfix
YutackPark Dec 8, 2024
e2096cf
add: multimodal preset
YutackPark Dec 8, 2024
3c9e9ff
fix: mm preset
YutackPark Dec 8, 2024
802b280
changelog
YutackPark Dec 8, 2024
8357f5a
fix: pandas dep
YutackPark Dec 8, 2024
c35cfff
fix
YutackPark Dec 11, 2024
13d046f
merge main
YutackPark Dec 11, 2024
0b7cbf4
version up + raise Error not warning if version is not found
YutackPark Dec 11, 2024
56dd6ac
merge main
YutackPark Dec 11, 2024
75c8b72
fix: backward version checkp
YutackPark Dec 12, 2024
4dbfdb0
Merge branch 'dev' into multimodal
YutackPark Dec 12, 2024
6c18ffc
fix: non-species wise shift scale continue
YutackPark Dec 14, 2024
30410fa
Merge branch 'multimodal' of github.com:MDIL-SNU/SevenNet into multim…
YutackPark Dec 14, 2024
12f3897
Merge pull request #147 from MDIL-SNU/main
YutackPark Dec 16, 2024
7d1477e
Update README.md
Jaesun0912 Dec 16, 2024
a93c51e
Update README.md
Jaesun0912 Dec 16, 2024
a9cd332
Update README.md
Jaesun0912 Dec 16, 2024
443bcf9
Update README.md
Jaesun0912 Dec 16, 2024
ce2b894
Update README.md
Jaesun0912 Dec 16, 2024
2c1347e
(docs) Upload 7net-MF-0 and comment yaml
Jaesun0912 Dec 16, 2024
49c1e71
(feat) load 7net-MF-0 and test it
Jaesun0912 Dec 16, 2024
0918be6
Update README.md
Jaesun0912 Dec 16, 2024
401343b
Update README.md
Jaesun0912 Dec 16, 2024
3f12d6f
Update README.md
Jaesun0912 Dec 17, 2024
c87ee2b
Create README.md
Jaesun0912 Dec 17, 2024
6a83b10
Update README.md
Jaesun0912 Dec 17, 2024
5c10ce7
Update README.md
Jaesun0912 Dec 17, 2024
54f6756
chore
YutackPark Dec 17, 2024
dcd4a9b
fix: inference with .pt dataset
YutackPark Dec 17, 2024
ae5bde3
Update README.md
Jaesun0912 Dec 18, 2024
9350161
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2024
e31d845
Update README.md
Jaesun0912 Dec 18, 2024
526fc3f
fix: f
YutackPark Dec 18, 2024
019538e
add: lammps modal test
YutackPark Dec 18, 2024
fec3e93
Merge branch 'multimodal' of github.com:MDIL-SNU/SevenNet into multim…
YutackPark Dec 18, 2024
f0eb515
Update README.md
Jaesun0912 Dec 23, 2024
6214a8d
Update README.md
Jaesun0912 Dec 23, 2024
72dabbb
fix: > 2 pt file in the filelist
YutackPark Dec 24, 2024
41d3831
Merge branch 'multimodal' of github.com:MDIL-SNU/SevenNet into multim…
YutackPark Dec 24, 2024
24 changes: 24 additions & 0 deletions CHANGELOG.md
@@ -1,6 +1,30 @@
# Changelog
All notable changes to this project will be documented in this file.

## WIP [0.11.0]
### Added
- Build multi-fidelity models (SevenNet-MF) based on the modalities given in the input yaml
- Modality support for sevenn_inference, sevenn_get_modal, and SevenNetCalculator
- [cli] sevenn_cp tool for checkpoint summary, input generation, and multi-modal routines
- Modality append / assign using sevenn_cp
- Loss weighting for energy, force, and stress per data label
- Ignore unlabelled data when calculating the loss (e.g. stress of non-PBC structures)
- Dict-style dataset input for multi-modal and data-weight settings (illustrated in the sketch below)

### Added (code)
- sevenn.train.modal_dataset SevenNetMultiModalDataset
- sevenn.scripts.backward_compatibility.py
- sevenn.checkpoint.py

### Changed
- Sort tensor-product instructions in convolution (+ fix flipped w3j coefficients of old models)
- Lazy initialization for `IrrepsLinear` and `SelfConnection*`
- Checkpoint handling moved to `sevenn/checkpoint.py`
- Require e3nn >= 0.5.0, following its changed CG coefficient convention
- pandas added as a dependency

### Fixed
- Further refactoring of shift/scale handling, plus a few bug fixes
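As an illustration of the dict-style dataset input above (not part of the diff): a minimal sketch of how a multi-modal training set with per-dataset weights could be declared in `input.yaml`. The keys `data_modality` and `data_weight` mirror constants added in `sevenn/_keys.py`; the surrounding structure (`load_trainset_path`, `file_list`, `file`) and the modality names are illustrative assumptions — see the `multi_modal` preset for the actual schema.

```yaml
# Hedged sketch: dict-style dataset entries with modality and data weights.
# 'data_modality' / 'data_weight' come from sevenn/_keys.py in this PR;
# the nesting and the remaining keys are assumptions for illustration only.
data:
  load_trainset_path:
    - data_modality: 'PBE'
      data_weight: {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
      file_list:
        - {file: 'pbe_train.extxyz'}
    - data_modality: 'R2SCAN'
      data_weight: {'energy': 1.0, 'force': 1.0, 'stress': 1.0}
      file_list:
        - {file: 'r2scan_train.extxyz'}
```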

## [0.10.3]
### Added
28 changes: 27 additions & 1 deletion README.md
@@ -15,6 +15,7 @@ SevenNet (Scalable EquiVariance Enabled Neural Network) is a graph neural networ…
- Python [Atomic Simulation Environment (ASE)](https://wiki.fysik.dtu.dk/ase/) calculator support
- GPU-parallelized molecular dynamics with LAMMPS
- CUDA-accelerated D3 (van der Waals) dispersion
- Multi-fidelity training for combining multiple databases with different calculation settings

## Pre-trained models
So far, we have released three pre-trained SevenNet models. Each model has various hyperparameters and training sets, resulting in different accuracy and speed. Please read the descriptions below carefully and choose the model that best suits your purpose.
@@ -27,6 +28,18 @@ Additionally, `keywords` can be called in other parts of SevenNet, such as `seve…

---

### **SevenNet-MF-0 (16Dec2024)**
> Keywords in ASE: `7net-MF-0` and `SevenNet-MF-0`

The model is trained on the PBE(+U) and $\mathrm{r}^{2}\mathrm{SCAN}$ databases provided by the Materials Project.
It has the same architecture as **SevenNet-0 (11Jul2024)**, except that it contains additional 'fidelity-dependent' parameters used for multi-fidelity training.
The overhead of evaluating these fidelity-dependent parameters is negligible, so the inference speed is almost the same as **SevenNet-0 (11Jul2024)**.

Details on using this model, as well as choosing the level of theory for inference, can be found [here](./sevenn/pretrained_potentials/SevenNet_MF_0).

* Training set MAE ($\mathrm{r}^{2}$ SCAN): 10.8 meV/atom (energy), 0.018 eV/Ang. (force), and 0.58 kbar (stress)
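Below is a usage illustration, not part of this diff: a minimal sketch of running the model through ASE. The import path follows the existing SevenNet README; the `modal=` keyword and the modality names `'PBE'` / `'R2SCAN'` are assumptions about how the level of theory is selected — check the SevenNet_MF_0 README linked above for the exact interface.

```python
from ase.build import bulk

from sevenn.sevennet_calculator import SevenNetCalculator

# Load SevenNet-MF-0 and select the fidelity used for inference.
# The 'modal' keyword and its allowed values are assumptions for illustration.
calc = SevenNetCalculator('7net-MF-0', modal='R2SCAN')

atoms = bulk('Si', 'diamond', a=5.43)
atoms.calc = calc
print(atoms.get_potential_energy())  # total energy (eV) at the chosen level of theory
```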
---

### **SevenNet-l3i5 (12Dec2024)**
> Keywords in ASE: `7net-l3i5` and `SevenNet-l3i5`

@@ -117,7 +130,7 @@ With the `sevenn_preset` command, the input file that sets the training paramete…
sevenn_preset {preset keyword} > input.yaml
```

Available preset keywords are: `base`, `fine_tune`, `sevennet-0`, and `sevennet-l3i5`.
Available preset keywords are: `base`, `fine_tune`, `multi_modal`, `sevennet-0`, and `sevennet-l3i5`.
Check comments in the preset yaml files for explanations. For fine-tuning, note that most model hyperparameters cannot be modified unless explicitly indicated.
To reuse a preprocessed training set, you can specify `sevenn_data/${dataset_name}.pt` to the `load_trainset_path:` in the `input.yaml`.
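For example, the preset added by this PR can be generated the same way:

```
sevenn_preset multi_modal > input.yaml
```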

@@ -307,3 +320,16 @@ If you use this code, please cite our paper:
pages = {4857--4868},
}
```

If you utilize the multi-fidelity feature of this code or the pretrained model SevenNet-MF-0, please cite the following paper:
```txt
@article{kim_sevennet_mf_2024,
title = {Data-Efficient Multifidelity Training for High-Fidelity Machine Learning Interatomic Potentials},
volume = {xx},
doi = {10.1021/jacs.4c14455},
number = {xx},
journal = {J. Am. Chem. Soc.},
author = {Kim, Jaesun and Kim, Jisu and Kim, Jaehoon and Lee, Jiho and Park, Yutack and Kang, Youngho and Han, Seungwu},
year = {2024},
pages = {xx--xx},
}
```
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "sevenn"
version = "0.10.3"
version = "0.11.0"
authors = [
{ name = "Yutack Park", email = "[email protected]" },
{ name = "Haekwan Jeon", email = "[email protected]" },
@@ -21,12 +21,13 @@ dependencies = [
"ase",
"braceexpand",
"pyyaml",
"e3nn",
"e3nn>=0.5.0",
"tqdm",
"scikit-learn",
"torch_geometric>=2.5.0",
"numpy<2.0",
"matscipy",
"pandas",
]
[project.optional-dependencies]
test = ["matscipy", "pytest-cov>=5"]
@@ -39,6 +40,7 @@ sevenn_graph_build = "sevenn.main.sevenn_graph_build:main"
sevenn_inference = "sevenn.main.sevenn_inference:main"
sevenn_patch_lammps = "sevenn.main.sevenn_patch_lammps:main"
sevenn_preset = "sevenn.main.sevenn_preset:main"
sevenn_cp = "sevenn.main.sevenn_cp:main"

[project.urls]
Homepage = "https://github.com/MDIL-SNU/SevenNet"
@@ -59,6 +61,7 @@ sevenn = [
"pretrained_potentials/SevenNet_0__11Jul2024/checkpoint_sevennet_0.pth",
"pretrained_potentials/SevenNet_0__22May2024/checkpoint_sevennet_0.pth",
"pretrained_potentials/SevenNet_l3i5/checkpoint_l3i5.pth",
"pretrained_potentials/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth",
"py.typed",
]

2 changes: 1 addition & 1 deletion setup.cfg
@@ -10,5 +10,5 @@ include_trailing_comma=True
force_grid_wrap=0
use_parentheses=True
line_length=80
known_third_party=ase,braceexpand,e3nn,numpy,pytest,setuptools,sklearn,torch,torch_geometric,tqdm,yaml
known_third_party=ase,braceexpand,e3nn,numpy,pandas,pytest,setuptools,sklearn,torch,torch_geometric,tqdm,yaml
known_first_party=
45 changes: 37 additions & 8 deletions sevenn/_const.py
@@ -57,13 +57,16 @@
SEVENNET_l3i5 = (
f'{_prefix}/SevenNet_l3i5/checkpoint_l3i5.pth'
)
SEVENNET_MF_0 = (
f'{_prefix}/SevenNet_MF_0/checkpoint_sevennet_mf_0.pth'
)


# to avoid torch script to compile torch_geometry.data
AtomGraphDataType = Dict[str, torch.Tensor]


class LossType(Enum):
class LossType(Enum): # only used for train_v1, do not use it afterwards
ENERGY = 'energy' # eV or eV/atom
FORCE = 'force' # eV/A
STRESS = 'stress' # kB
@@ -80,44 +83,48 @@ def error_record_condition(x):
if v[0] == 'TotalLoss':
continue
if v[1] not in SUPPORTING_METRICS:
print('w')
return False
return True


DEFAULT_E3_EQUIVARIANT_MODEL_CONFIG = {
KEY.IRREPS_MANUAL: False,
KEY.CUTOFF: 4.5,
KEY.NODE_FEATURE_MULTIPLICITY: 32,
KEY.IRREPS_MANUAL: False,
KEY.LMAX: 1,
KEY.LMAX_EDGE: -1, # -1 means lmax_edge = lmax
KEY.LMAX_NODE: -1, # -1 means lmax_node = lmax
KEY.IS_PARITY: True,
KEY.NUM_CONVOLUTION: 3,
KEY.RADIAL_BASIS: {
KEY.RADIAL_BASIS_NAME: 'bessel',
},
KEY.CUTOFF_FUNCTION: {
KEY.CUTOFF_FUNCTION_NAME: 'poly_cut',
},
KEY.ACTIVATION_RADIAL: 'silu',
KEY.CUTOFF: 4.5,
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64],
KEY.NUM_CONVOLUTION: 3,
KEY.ACTIVATION_SCARLAR: {'e': 'silu', 'o': 'tanh'},
KEY.ACTIVATION_GATE: {'e': 'silu', 'o': 'tanh'},
KEY.CONVOLUTION_WEIGHT_NN_HIDDEN_NEURONS: [64, 64],
# KEY.AVG_NUM_NEIGH: True, # deprecated
# KEY.TRAIN_AVG_NUM_NEIGH: False, # deprecated
KEY.CONV_DENOMINATOR: 'avg_num_neigh',
KEY.TRAIN_DENOMINTAOR: False,
KEY.TRAIN_SHIFT_SCALE: False,
# KEY.OPTIMIZE_BY_REDUCE: True, # deprecated, always True
KEY.USE_BIAS_IN_LINEAR: False,
KEY.USE_MODAL_NODE_EMBEDDING: False,
KEY.USE_MODAL_SELF_INTER_INTRO: False,
KEY.USE_MODAL_SELF_INTER_OUTRO: False,
KEY.USE_MODAL_OUTPUT_BLOCK: False,
KEY.READOUT_AS_FCN: False,
# Applied af readout as fcn is True
KEY.READOUT_FCN_HIDDEN_NEURONS: [30, 30],
KEY.READOUT_FCN_ACTIVATION: 'relu',
KEY.SELF_CONNECTION_TYPE: 'nequip',
KEY.INTERACTION_TYPE: 'nequip',
KEY._NORMALIZE_SPH: True,
KEY.CUEQUIVARIANCE_CONFIG: {},
}


@@ -144,13 +151,18 @@
KEY.TRAIN_SHIFT_SCALE: bool,
KEY.TRAIN_DENOMINTAOR: bool,
KEY.USE_BIAS_IN_LINEAR: bool,
KEY.USE_MODAL_NODE_EMBEDDING: bool,
KEY.USE_MODAL_SELF_INTER_INTRO: bool,
KEY.USE_MODAL_SELF_INTER_OUTRO: bool,
KEY.USE_MODAL_OUTPUT_BLOCK: bool,
KEY.READOUT_AS_FCN: bool,
KEY.READOUT_FCN_HIDDEN_NEURONS: list,
KEY.READOUT_FCN_ACTIVATION: str,
KEY.ACTIVATION_RADIAL: str,
KEY.SELF_CONNECTION_TYPE: lambda x: x in IMPLEMENTED_SELF_CONNECTION_TYPE,
KEY.INTERACTION_TYPE: lambda x: x in IMPLEMENTED_INTERACTION_TYPE,
KEY._NORMALIZE_SPH: bool,
KEY.CUEQUIVARIANCE_CONFIG: dict,
}


@@ -179,8 +191,13 @@ def model_defaults(config):
KEY.COMPUTE_STATISTICS: True,
KEY.DATASET_TYPE: 'graph',
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: False,
KEY.USE_MODAL_WISE_SHIFT: False,
KEY.USE_MODAL_WISE_SCALE: False,
KEY.SHIFT: 'per_atom_energy_mean',
KEY.SCALE: 'force_rms',
# KEY.DATA_SHUFFLE: True,
# KEY.DATA_WEIGHT: False,
# KEY.DATA_MODALITY: False,
}

DATA_CONFIG_CONDITION = {
@@ -197,8 +214,12 @@
# KEY.USE_SPECIES_WISE_SHIFT_SCALE: bool,
KEY.SHIFT: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SHIFT,
KEY.SCALE: lambda x: type(x) in [float, list] or x in IMPLEMENTED_SCALE,
KEY.USE_MODAL_WISE_SHIFT: bool,
KEY.USE_MODAL_WISE_SCALE: bool,
# KEY.DATA_SHUFFLE: bool,
KEY.COMPUTE_STATISTICS: bool,
KEY.SAVE_DATASET: str,
# KEY.DATA_WEIGHT: bool,
# KEY.DATA_MODALITY: bool,
}


@@ -221,14 +242,16 @@ def data_defaults(config):
KEY.FORCE_WEIGHT: 0.1,
KEY.STRESS_WEIGHT: 1e-6, # SIMPLE-NN default
KEY.PER_EPOCH: 5,
KEY.USE_TESTSET: False,
# KEY.USE_TESTSET: False,
KEY.CONTINUE: {
KEY.CHECKPOINT: False,
KEY.RESET_OPTIMIZER: False,
KEY.RESET_SCHEDULER: False,
KEY.RESET_EPOCH: False,
KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: True,
KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: True,
},
# KEY.DEFAULT_MODAL: 'common',
KEY.CSV_LOG: 'log.csv',
KEY.NUM_WORKERS: 0,
KEY.IS_TRAIN_STRESS: True,
Expand All @@ -240,6 +263,8 @@ def data_defaults(config):
['TotalLoss', 'None'],
],
KEY.BEST_METRIC: 'TotalLoss',
KEY.USE_WEIGHT: False,
KEY.USE_MODALITY: False,
}


@@ -257,12 +282,16 @@ def data_defaults(config):
KEY.RESET_SCHEDULER: bool,
KEY.RESET_EPOCH: bool,
KEY.USE_STATISTIC_VALUES_OF_CHECKPOINT: bool,
KEY.USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY: bool,
},
KEY.DEFAULT_MODAL: str,
KEY.IS_TRAIN_STRESS: bool,
KEY.TRAIN_SHUFFLE: bool,
KEY.ERROR_RECORD: error_record_condition,
KEY.BEST_METRIC: str,
KEY.CSV_LOG: str,
KEY.USE_MODALITY: bool,
KEY.USE_WEIGHT: bool,
}


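For orientation only (not part of this diff): the new keys above map onto `input.yaml` options. Below is a hedged sketch of where they could be placed, grouped as in the defaults above; the key names are taken from the diff, while the section layout and example values are assumptions.

```yaml
# Hedged sketch: placement of the new multi-modal options in input.yaml.
# Key names follow the defaults added in sevenn/_const.py; the values and
# section grouping shown here are illustrative assumptions.
model:
  use_modal_node_embedding: False
  use_modal_self_inter_intro: True
  use_modal_self_inter_outro: True
  use_modal_output_block: True
data:
  use_modal_wise_shift: True
  use_modal_wise_scale: False
train:
  use_modality: True     # enable multi-fidelity training
  use_weight: True       # enable per-data loss weights
  # default_modal: 'PBE' # modality assumed for data that do not specify one
```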
31 changes: 28 additions & 3 deletions sevenn/_keys.py
@@ -12,8 +12,6 @@

from typing import Final

from torch.jit import CompilationUnit

# see
# https://github.com/pytorch/pytorch/issues/52312
# for FYI
@@ -40,6 +38,10 @@
NODE_FEATURE: Final[str] = 'x' # (N, ?) PyG
NODE_FEATURE_GHOST: Final[str] = 'x_ghost'
NODE_ATTR: Final[str] = 'node_attr' # (N, N_species) from one_hot
MODAL_ATTR: Final[str] = (
'modal_attr' # (1, N_modalities) for handling multi-modal
)
MODAL_TYPE: Final[str] = 'modal_type' # (1) one-hot index of modal
EDGE_ATTR: Final[str] = 'edge_attr' # (from spherical harmonics)
EDGE_EMBEDDING: Final[str] = 'edge_embedding' # (from edge embedding)

@@ -69,7 +71,11 @@
NUM_ATOMS: Final[str] = 'num_atoms' # int
NUM_GHOSTS: Final[str] = 'num_ghosts'
NLOCAL: Final[str] = 'nlocal' # only for lammps parallel, must be on cpu
USER_LABEL: Final[str] = 'user_label' # Deprecated from v0.9.6
USER_LABEL: Final[str] = 'user_label'
DATA_WEIGHT: Final[str] = 'data_weight' # weight for given data
DATA_MODALITY: Final[str] = (
'data_modality' # modality of given data. e.g. PBE and SCAN
)
BATCH: Final[str] = 'batch'

TAG = 'tag' # replace USER_LABEL
@@ -125,6 +131,9 @@
RESET_SCHEDULER = 'reset_scheduler'
RESET_EPOCH = 'reset_epoch'
USE_STATISTIC_VALUES_OF_CHECKPOINT = 'use_statistic_values_of_checkpoint'
USE_STATISTIC_VALUES_FOR_CP_MODAL_ONLY = (
'use_statistic_values_for_cp_modal_only'
)

CSV_LOG = 'csv_log'

@@ -140,6 +149,10 @@
DDP_BACKEND = 'ddp_backend'
PER_EPOCH = 'per_epoch'

USE_WEIGHT = 'use_weight'
USE_MODALITY = 'use_modality'
DEFAULT_MODAL = 'default_modal'


# ==================================================#
# ~~~~~~~~ KEY for model configuration ~~~~~~~~~~~ #
@@ -149,9 +162,12 @@
MODEL_TYPE = '_model_type'
CUTOFF = 'cutoff'
CHEMICAL_SPECIES = 'chemical_species'
MODAL_LIST = 'modal_list'
CHEMICAL_SPECIES_BY_ATOMIC_NUMBER = '_chemical_species_by_atomic_number'
NUM_SPECIES = '_number_of_species'
NUM_MODALITIES = '_number_of_modalities'
TYPE_MAP = '_type_map'
MODAL_MAP = '_modal_map'

# ~~ E3 equivariant model build configuration keys ~~ #
# see model_build default_config for type
@@ -181,6 +197,11 @@

USE_BIAS_IN_LINEAR = 'use_bias_in_linear'

USE_MODAL_NODE_EMBEDDING = 'use_modal_node_embedding'
USE_MODAL_SELF_INTER_INTRO = 'use_modal_self_inter_intro'
USE_MODAL_SELF_INTER_OUTRO = 'use_modal_self_inter_outro'
USE_MODAL_OUTPUT_BLOCK = 'use_modal_output_block'

READOUT_AS_FCN = 'readout_as_fcn'
READOUT_FCN_HIDDEN_NEURONS = 'readout_fcn_hidden_neurons'
READOUT_FCN_ACTIVATION = 'readout_fcn_activation'
@@ -191,11 +212,15 @@
SCALE = 'scale'

USE_SPECIES_WISE_SHIFT_SCALE = 'use_species_wise_shift_scale'
USE_MODAL_WISE_SHIFT = 'use_modal_wise_shift'
USE_MODAL_WISE_SCALE = 'use_modal_wise_scale'

TRAIN_SHIFT_SCALE = 'train_shift_scale'
TRAIN_DENOMINTAOR = 'train_denominator'
INTERACTION_TYPE = 'interaction_type'
TRAIN_AVG_NUM_NEIGH = 'train_avg_num_neigh' # deprecated

CUEQUIVARIANCE_CONFIG = 'cuequivariance_config'

_NORMALIZE_SPH = '_normalize_sph'
OPTIMIZE_BY_REDUCE = 'optimize_by_reduce'