Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New ML energy solvers: Nequip and MLIP-3 #73

Merged
merged 16 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions .github/workflows/Test_abICS.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ jobs:
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
testname: [Unit, Sampling, ActiveLearn]
testname: [Unit, Sampling, ActiveLearnAenet, ActiveLearnNequip, ActiveLearnMLIP-3]
exclude:
- python-version: 3.7
testname: ActiveLearnNequip
fail-fast: false

steps:
Expand Down Expand Up @@ -42,8 +45,14 @@ jobs:
cd ../potts_pamc
sh ./run.sh
;;
ActiveLearn ) cd tests/integration/active_learn
sh ./install_aenet.sh
sh ./run.sh ;;
ActiveLearnAenet ) cd tests/integration/active_learn_aenet
sh ./install_aenet.sh
sh ./run.sh ;;
ActiveLearnNequip ) cd tests/integration/active_learn_nequip
sh ./install_nequip.sh
sh ./run.sh ;;
ActiveLearnMLIP-3 ) cd tests/integration/active_learn_mlip3
sh ./install_mlip3.sh
sh ./run.sh ;;
* ) echo "Unknown testname";;
esac
22 changes: 14 additions & 8 deletions abics/applications/latgas_abinitio_interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

# from .default_observer import *
from .map2perflat import *
from .aenet_trainer import *

from .vasp import VASPSolver
from .qe import QESolver
from .aenet import AenetSolver
from .aenet_pylammps import AenetPyLammpsSolver
from .openmx import OpenMXSolver
from .user_function_solver import UserFunctionSolver
from .base_solver import register_solver
from .base_trainer import register_trainer

register_solver("vasp", "VASPSolver", "abics.applications.latgas_abinitio_interface.vasp")
register_solver("qe", "QESolver", "abics.applications.latgas_abinitio_interface.qe")
register_solver("openmx", "OpenMXSolver", "abics.applications.latgas_abinitio_interface.openmx")
register_solver("aenet", "AenetSolver", "abics.applications.latgas_abinitio_interface.aenet")
register_solver("nequip", "NequipSolver", "abics.applications.latgas_abinitio_interface.nequip")
register_solver("mlip_3", "MLIP3Solver", "abics.applications.latgas_abinitio_interface.mlip_3")
register_solver("User", "UserFunctionSolver", "abics.applications.latgas_abinitio_interface.user_function_solver")

register_trainer("aenet", "AenetTrainer", "abics.applications.latgas_abinitio_interface.aenet_trainer")
register_trainer("nequip", "NequipTrainer", "abics.applications.latgas_abinitio_interface.nequip_trainer")
register_trainer("mlip_3", "MLIP3Trainer", "abics.applications.latgas_abinitio_interface.mlip_3_trainer")
112 changes: 4 additions & 108 deletions abics/applications/latgas_abinitio_interface/aenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

"""
Adapted from pymatgen.io.xcrysden distributed under the MIT License
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
"""

from __future__ import annotations

import os
Expand All @@ -28,105 +22,9 @@
import numpy as np
from pymatgen.core import Structure

from .base_solver import SolverBase, register_solver
from .base_solver import SolverBase
from .params import ALParams, DFTParams


def to_XSF(structure: Structure, write_force_zero=False):
"""
Returns a string with the structure in XSF format
See http://www.xcrysden.org/doc/XSF.html
"""
lines = []
app = lines.append

app("CRYSTAL")
app("# Primitive lattice vectors in Angstrom")
app("PRIMVEC")
cell = structure.lattice.matrix
for i in range(3):
app(" %.14f %.14f %.14f" % tuple(cell[i]))

cart_coords = structure.cart_coords
app("# Cartesian coordinates in Angstrom.")
app("PRIMCOORD")
app(" %d 1" % len(cart_coords))
species = structure.species
site_properties = structure.site_properties
if "forces" not in site_properties.keys():
write_force_zero = True
else:
forces = site_properties["forces"]

if write_force_zero:
for a in range(len(cart_coords)):
app(
str(species[a])
+ " %20.14f %20.14f %20.14f" % tuple(cart_coords[a])
+ " 0.0 0.0 0.0"
)
else:
for a in range(len(cart_coords)):
app(
str(species[a])
+ " %20.14f %20.14f %20.14f" % tuple(cart_coords[a])
+ " %20.14f %20.14f %20.14f" % tuple(forces[a])
)

return "\n".join(lines)


def from_XSF(input_string: str):
"""
Initialize a `Structure` object from a string with data in XSF format.

Args:
input_string: String with the structure in XSF format.
See http://www.xcrysden.org/doc/XSF.html
cls_: Structure class to be created. default: pymatgen structure

"""
# CRYSTAL see (1)
# these are primitive lattice vectors (in Angstroms)
# PRIMVEC
# 0.0000000 2.7100000 2.7100000 see (2)
# 2.7100000 0.0000000 2.7100000
# 2.7100000 2.7100000 0.0000000

# these are conventional lattice vectors (in Angstroms)
# CONVVEC
# 5.4200000 0.0000000 0.0000000 see (3)
# 0.0000000 5.4200000 0.0000000
# 0.0000000 0.0000000 5.4200000

# these are atomic coordinates in a primitive unit cell (in Angstroms)
# PRIMCOORD
# 2 1 see (4)
# 16 0.0000000 0.0000000 0.0000000 see (5)
# 30 1.3550000 -1.3550000 -1.3550000

lattice, coords, species = [], [], []
lines = input_string.splitlines()

for i in range(len(lines)):
if "PRIMVEC" in lines[i]:
for j in range(i + 1, i + 4):
lattice.append([float(c) for c in lines[j].split()])

if "PRIMCOORD" in lines[i]:
num_sites = int(lines[i + 1].split()[0])

for j in range(i + 2, i + 2 + num_sites):
tokens = lines[j].split()
species.append(tokens[0])
coords.append([float(j) for j in tokens[1:4]])
break
else:
raise ValueError("Invalid XSF data")

s = Structure(lattice, species, coords, coords_are_cartesian=True)
return s

from .util import structure_to_XSF, structure_from_XSF

class AenetSolver(SolverBase):
"""
Expand Down Expand Up @@ -183,7 +81,7 @@ def update_info_by_structure(self, structure: Structure):
if self.ignore_species is not None:
structure = structure.copy()
structure.remove_species(self.ignore_species)
self.pos_info = to_XSF(structure)
self.pos_info = structure_to_XSF(structure)

def update_info_from_files(self, output_dir, rerun):
"""
Expand Down Expand Up @@ -262,7 +160,7 @@ def get_results(self, output_dir):
# Read results from files in output_dir and calculate values
Phys = namedtuple("PhysValues", ("energy", "structure"))
with open(os.path.join(output_dir, "structure.xsf")) as f:
structure = from_XSF(f.read())
structure = structure_from_XSF(f.read())
with open(os.path.join(output_dir, "stdout")) as f:
lines = f.read()
fi_io = io.StringIO(lines)
Expand Down Expand Up @@ -291,5 +189,3 @@ def create(cls, params: ALParams | DFTParams):
ignore_species = params.ignore_species
run_scheme = params.solver_run_scheme
return cls(path, ignore_species, run_scheme)

register_solver("aenet", AenetSolver)
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import numpy as np
from pymatgen.core import Structure

from .base_solver import SolverBase, register_solver
from .base_solver import SolverBase
from .params import ALParams, DFTParams


Expand Down Expand Up @@ -227,6 +227,3 @@ def solver_run_schemes(self):
def create(cls, params: ALParams | DFTParams):
ignore_species = params.ignore_species
return cls(ignore_species)


register_solver("aenetpylammps", AenetPyLammpsSolver)
36 changes: 26 additions & 10 deletions abics/applications/latgas_abinitio_interface/aenet_trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
# ab-Initio Configuration Sampling tool kit (abICS)
# Copyright (C) 2019- The University of Tokyo
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from __future__ import annotations
from typing import Sequence
from typing import Sequence, Dict

import numpy as np
import os, pathlib, shutil, subprocess, shlex
Expand All @@ -8,30 +24,30 @@
from pymatgen.core import Structure

from abics.util import expand_cmd_path
from abics.applications.latgas_abinitio_interface import aenet
from abics.applications.latgas_abinitio_interface.base_trainer import TrainerBase
from abics.applications.latgas_abinitio_interface.util import structure_to_XSF

class aenet_trainer:
class AenetTrainer(TrainerBase):
def __init__(
self,
structures: Sequence[Structure],
energies: Sequence[float],
generate_inputdir: os.PathLike,
train_inputdir: os.PathLike,
predict_inputdir: os.PathLike,
generate_exe: str,
train_exe: str,
execute_commands: Dict,
):
self.structures = structures
self.energies = energies
self.generate_inputdir = generate_inputdir
self.train_inputdir = train_inputdir
self.predict_inputdir = predict_inputdir
generate_exe = execute_commands["generate"]
self.generate_exe = [expand_cmd_path(e) for e in shlex.split(generate_exe)]
self.generate_exe.append("generate.in")
train_exe = execute_commands["train"]
self.train_exe = [expand_cmd_path(e) for e in shlex.split(train_exe)]
self.train_exe.append("train.in")
# self.generate_exe = generate_exe
# self.train_exe = train_exe
assert len(self.structures) == len(self.energies)
self.numdata = len(self.structures)
self.is_prepared = False
Expand All @@ -48,15 +64,15 @@ def prepare(self, latgas_mode = True, st_dir = "aenetXSF"):
xsfdir = os.getcwd()
if latgas_mode:
for i, st in enumerate(self.structures):
xsf_string = aenet.to_XSF(st, write_force_zero=False)
xsf_string = structure_to_XSF(st, write_force_zero=False)
xsf_string = (
"# total energy = {} eV\n\n".format(self.energies[i]) + xsf_string
)
with open("structure.{}.xsf".format(i), "w") as fi:
fi.write(xsf_string)
else:
for i, st in enumerate(self.structures):
xsf_string = aenet.to_XSF(st, write_force_zero=False)
xsf_string = structure_to_XSF(st, write_force_zero=False)
xsf_string = (
"# total energy = {} eV\n\n".format(self.energies[i]) + xsf_string
)
Expand Down Expand Up @@ -170,7 +186,7 @@ def train(self, train_dir = "train"):
os.chdir(pathlib.Path(os.getcwd()).parent)
self.is_trained = True

def new_baseinput(self, baseinput_dir):
def new_baseinput(self, baseinput_dir, train_dir=""):
try:
assert self.is_trained
except AssertionError as e:
Expand Down
20 changes: 13 additions & 7 deletions abics/applications/latgas_abinitio_interface/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,21 @@ def create(cls, params: ALParams | DFTParams) -> SolverBase:

__solver_table = {}

def register_solver(solver_name: str, solver_class) -> None:
def register_solver(solver_name: str, solver_class: str, solver_module: str) -> None:
"""
Register solver class.

Parameters
----------
solver_name : str
Solver name (case insensible).
solver_class : SolverBase
solver_class : str
Solver class, which should be a subclass of SolverBase.
solver_module : str
Module name including the solver class.
"""

if SolverBase not in solver_class.mro():
raise TypeError("solver_class must be a subclass of SolverBase")
__solver_table[solver_name.lower()] = solver_class
__solver_table[solver_name.lower()] = (solver_class, solver_module)


def create_solver(solver_name, params: ALParams | DFTParams) -> SolverBase:
Expand All @@ -236,5 +236,11 @@ def create_solver(solver_name, params: ALParams | DFTParams) -> SolverBase:
sn = solver_name.lower()
if sn not in __solver_table:
raise ValueError(f"Unknown solver: {solver_name}")
solver_class = __solver_table[sn]
return solver_class.create(params)

import importlib
solver_class_name, solver_module = __solver_table[sn]
mod = importlib.import_module(solver_module)
solver_class = getattr(mod, solver_class_name)
if SolverBase not in solver_class.mro():
raise TypeError("solver_class must be a subclass of SolverBase")
return solver_class.create(params)
Loading