Skip to content

Commit

Permalink
change exe_command format
Browse files Browse the repository at this point in the history
  • Loading branch information
yomichi committed Nov 5, 2024
1 parent 82fda19 commit a04fcc5
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 37 deletions.
9 changes: 4 additions & 5 deletions abics/applications/latgas_abinitio_interface/aenet_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

from __future__ import annotations
from typing import Sequence
from typing import Sequence, Dict

import numpy as np
import os, pathlib, shutil, subprocess, shlex
Expand All @@ -35,20 +35,19 @@ def __init__(
generate_inputdir: os.PathLike,
train_inputdir: os.PathLike,
predict_inputdir: os.PathLike,
generate_exe: str,
train_exe: str,
execute_commands: Dict,
):
self.structures = structures
self.energies = energies
self.generate_inputdir = generate_inputdir
self.train_inputdir = train_inputdir
self.predict_inputdir = predict_inputdir
generate_exe = execute_commands["generate"]
self.generate_exe = [expand_cmd_path(e) for e in shlex.split(generate_exe)]
self.generate_exe.append("generate.in")
train_exe = execute_commands["train"]
self.train_exe = [expand_cmd_path(e) for e in shlex.split(train_exe)]
self.train_exe.append("train.in")
# self.generate_exe = generate_exe
# self.train_exe = train_exe
assert len(self.structures) == len(self.energies)
self.numdata = len(self.structures)
self.is_prepared = False
Expand Down
5 changes: 2 additions & 3 deletions abics/applications/latgas_abinitio_interface/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from __future__ import annotations

from typing import Sequence, Type
from typing import Sequence, Type, Dict
import os

class TrainerBase(object):
Expand All @@ -28,8 +28,7 @@ def __init__(
generate_inputdir: os.PathLike,
train_inputdir: os.PathLike,
predict_inputdir: os.PathLike,
generate_exe: str,
train_exe: str,
execute_commands: Dict[str, str]
):
...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import shutil
import subprocess
import time
from typing import Sequence
from typing import Sequence, Dict

from pymatgen.core import Structure

Expand All @@ -43,14 +43,14 @@ def __init__(
generate_inputdir: os.PathLike,
train_inputdir: os.PathLike,
predict_inputdir: os.PathLike,
generate_exe: str,
train_exe: str,
execute_command: Dict,
):
self.structures = structures
self.energies = energies
self.generate_inputdir = generate_inputdir
self.train_inputdir = train_inputdir
self.predict_inputdir = predict_inputdir
train_exe = execute_command["train"]
self.train_exe = [
expand_cmd_path(e) for e in shlex.split(train_exe)
]
Expand Down
11 changes: 4 additions & 7 deletions abics/applications/latgas_abinitio_interface/nequip_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# along with this program. If not, see http://www.gnu.org/licenses/.

from __future__ import annotations
from typing import Sequence
from typing import Sequence, Dict

import numpy as np
import os, pathlib, shutil, subprocess, shlex
Expand All @@ -33,6 +33,7 @@
import ase
from ase import io
from ase.calculators.singlepoint import SinglePointCalculator

from nequip.utils import Config
from nequip.scripts import deploy as nequip_deploy

Expand All @@ -54,21 +55,17 @@ def __init__(
generate_inputdir: os.PathLike,
train_inputdir: os.PathLike,
predict_inputdir: os.PathLike,
generate_exe: str,
train_exe: str,
execute_commands: Dict,
# trainer_type: str,
):
self.structures = structures
self.energies = energies
self.generate_inputdir = generate_inputdir
self.train_inputdir = train_inputdir
self.predict_inputdir = predict_inputdir
self.generate_exe = [expand_cmd_path(e) for e in shlex.split(generate_exe)]
self.generate_exe.append("generate.in")
train_exe = execute_commands["train"]
self.train_exe = [expand_cmd_path(e) for e in shlex.split(train_exe)]
self.train_exe.append("input.yaml")
# self.generate_exe = generate_exe
# self.train_exe = train_exe
assert len(self.structures) == len(self.energies)
self.numdata = len(self.structures)
self.is_prepared = False
Expand Down
13 changes: 11 additions & 2 deletions abics/applications/latgas_abinitio_interface/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,18 @@ def from_dict(cls, d):
)
params.solver = d["type"]
exe_command = d["exe_command"]
params.exe_command = {}
if isinstance(exe_command, str):
exe_command = [exe_command]
params.exe_command = exe_command
params.exe_command = {"train": exe_command}
elif isinstance(exe_command, list):
# For backward compatibility
for i, cmd in enumerate(exe_command):
if i == 0:
params.exe_command["generate"] = cmd
elif i == 1:
params.exe_command["train"] = cmd
elif isinstance(exe_command, dict):
params.exe_command = exe_command
params.solver_run_scheme = d.get("run_scheme", "subprocess")
params.ignore_species = d.get("ignore_species", None)
params.vac_map = d.get("vac_map", [])
Expand Down
6 changes: 1 addition & 5 deletions abics/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,6 @@ def main_impl(params_root: MutableMapping):
train_input_dirs.append(os.path.join(d, "train"))
predict_input_dirs.append(os.path.join(d, "predict"))

generate_exe = trainer_commands[0]
train_exe = trainer_commands[1]

trainer_class = get_trainer_class(trainer_type)
trainers = []
for i in range(len(trainer_input_dirs)):
Expand All @@ -230,8 +227,7 @@ def main_impl(params_root: MutableMapping):
generate_input_dirs[i],
train_input_dirs[i],
predict_input_dirs[i],
generate_exe,
train_exe,
trainer_commands,
)
)

Expand Down
8 changes: 4 additions & 4 deletions tests/integration/active_learn_aenet/input.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ perturb = 0.05
[train]
type = 'aenet'
base_input_dir = './aenet_train_input'
exe_command = [
'~/opt/aenet/bin/generate.x_serial',
'mpiexec -np 2 --oversubscribe ~/opt/aenet/bin/train.x_mpi'
]
ignore_species = ["O"]
vac_map = []
restart = false

[train.exe_command]
generate = '~/opt/aenet/bin/generate.x_serial'
train = 'mpiexec -np 2 --oversubscribe ~/opt/aenet/bin/train.x_mpi'

[config]
unitcell = [[8.1135997772, 0.0000000000, 0.0000000000],
[0.0000000000, 8.1135997772, 0.0000000000],
Expand Down
5 changes: 1 addition & 4 deletions tests/integration/active_learn_mlip3/input.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ perturb = 0.05
[train]
type = 'mlip_3'
base_input_dir = './mlip-3_train_input'
exe_command = [
'./mlip-3/bin/mlp',
'./mlip-3/bin/mlp',
]
exe_command = { train = './mlip-3/bin/mlp' }
ignore_species = ["O"]
vac_map = []
restart = false
Expand Down
5 changes: 1 addition & 4 deletions tests/integration/active_learn_nequip/input.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ perturb = 0.05
[train]
type = 'nequip'
base_input_dir = './allegro_train_input'
exe_command = [
'',
'nequip-train'
]
exe_command = {train = 'nequip-train'}
ignore_species = ["O"]
vac_map = []
restart = false
Expand Down

0 comments on commit a04fcc5

Please sign in to comment.