Skip to content

Commit

Permalink
sketch of derivative model and builder
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Apr 3, 2024
1 parent 956d816 commit 763bf14
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 19 deletions.
12 changes: 6 additions & 6 deletions apax/nn/jax/model/builder.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import numpy as np

from apax.config import ModelConfig
from apax.layers.descriptor.basis_functions import GaussianBasis, RadialFunction
from apax.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor
from apax.layers.empirical import ZBLRepulsion
from apax.layers.readout import AtomisticReadout
from apax.layers.scaling import PerElementScaleShift
from apax.model.gmnn import AtomisticModel, EnergyDerivativeModel, EnergyModel
from apax.nn.jax.layers.descriptor.basis import GaussianBasis, RadialFunction
from apax.nn.jax.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor
from apax.nn.jax.layers.empirical import ZBLRepulsion
from apax.nn.jax.layers.readout import AtomisticReadout
from apax.nn.jax.layers.scaling import PerElementScaleShift
from apax.nn.jax.model.gmnn import AtomisticModel, EnergyDerivativeModel, EnergyModel


class ModelBuilder:
Expand Down
2 changes: 0 additions & 2 deletions apax/nn/jax/model/gmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ def __call__(


class EnergyDerivativeModel(nn.Module):
# Alternatively, should this be a function transformation?
energy_model: EnergyModel = EnergyModel()
corrections: list[EmpiricalEnergyTerm] = field(default_factory=lambda: [])
calc_stress: bool = False

def __call__(
Expand Down
137 changes: 137 additions & 0 deletions apax/nn/torch/model/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import numpy as np

from apax.config import ModelConfig
from apax.nn.torch.layers.descriptor.basis import GaussianBasis, RadialFunction
from apax.nn.torch.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor
# from apax.nn.torch.layers.empirical import ZBLRepulsion
from apax.nn.torch.layers.readout import AtomisticReadout
from apax.nn.torch.layers.scaling import PerElementScaleShift
from apax.nn.torch.model.gmnn import AtomisticModel, EnergyDerivativeModel, EnergyModel


class ModelBuilder:
def __init__(self, model_config: ModelConfig, n_species: int = 119):
self.config = model_config
self.n_species = n_species

def build_basis_function(self):
basis_fn = GaussianBasis(
n_basis=self.config["n_basis"],
r_min=self.config["r_min"],
r_max=self.config["r_max"],
dtype=self.config["descriptor_dtype"],
)
return basis_fn

def build_radial_function(self):
basis_fn = self.build_basis_function()
radial_fn = RadialFunction(
n_radial=self.config["n_radial"],
basis_fn=basis_fn,
n_species=self.n_species,
emb_init=self.config["emb_init"],
dtype=self.config["descriptor_dtype"],
)
return radial_fn

def build_descriptor(
self,
apply_mask,
):
radial_fn = self.build_radial_function()
descriptor = GaussianMomentDescriptor(
radial_fn=radial_fn,
n_contr=self.config["n_contr"],
dtype=self.config["descriptor_dtype"],
apply_mask=apply_mask,
)
return descriptor

def build_readout(self):
readout = AtomisticReadout(
units=self.config["nn"],
b_init=self.config["b_init"],
dtype=self.config["readout_dtype"],
)
return readout

def build_scale_shift(self, scale, shift):
scale_shift = PerElementScaleShift(
n_species=self.n_species,
scale=scale,
shift=shift,
dtype=self.config["scale_shift_dtype"],
)
return scale_shift

def build_atomistic_model(
self,
scale,
shift,
apply_mask,
):
descriptor = self.build_descriptor(apply_mask)
readout = self.build_readout()
scale_shift = self.build_scale_shift(scale, shift)

atomistic_model = AtomisticModel(descriptor, readout, scale_shift)
return atomistic_model

def build_energy_model(
self,
scale=1.0,
shift=0.0,
apply_mask=True,
init_box: np.array = np.array([0.0, 0.0, 0.0]),
inference_disp_fn=None,
):
atomistic_model = self.build_atomistic_model(
scale,
shift,
apply_mask,
)
corrections = []
# if self.config["use_zbl"]:
# repulsion = ZBLRepulsion(
# apply_mask=apply_mask,
# r_max=self.config["r_max"],
# )
# corrections.append(repulsion)

model = EnergyModel(
atomistic_model,
corrections=corrections,
init_box=init_box,
inference_disp_fn=inference_disp_fn,
)
return model

def build_energy_derivative_model(
self,
scale=1.0,
shift=0.0,
apply_mask=True,
init_box: np.array = np.array([0.0, 0.0, 0.0]),
inference_disp_fn=None,
):
energy_model = self.build_energy_model(
scale,
shift,
apply_mask,
init_box=init_box,
inference_disp_fn=inference_disp_fn,
)
corrections = []
if self.config["use_zbl"]:
repulsion = ZBLRepulsion(
apply_mask=apply_mask,
r_max=self.config["r_max"],
)
corrections.append(repulsion)

model = EnergyDerivativeModel(
energy_model,
corrections=corrections,
calc_stress=self.config["calc_stress"],
)
return model
80 changes: 69 additions & 11 deletions apax/nn/torch/model/gmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd

from apax.nn.torch.layers.descriptor import GaussianMomentDescriptor
from apax.nn.torch.layers.readout import AtomisticReadout
Expand Down Expand Up @@ -35,6 +36,21 @@ def forward(
return output


def get_displacement(init_box, inference_disp_fn):
if np.all(init_box < 1e-6):
# gas phase training and predicting
displacement_fn = space.free()[0]
displacement = space.map_bond(displacement_fn)
elif inference_disp_fn is None:
# for training on periodic systems
displacement = vmap(disp_fn, (0, 0, None, None), 0)
else:
mappable_displacement_fn = get_disp_fn(self.inference_disp_fn)
displacement = vmap(mappable_displacement_fn, (0, 0, None, None), 0)

return displacement


class EnergyModel(nn.Module):
def __init__(
self,
Expand All @@ -49,16 +65,7 @@ def __init__(
self.init_box = init_box
self.inference_disp_fn = inference_disp_fn

if np.all(self.init_box < 1e-6):
# gas phase training and predicting
displacement_fn = space.free()[0]
self.displacement = space.map_bond(displacement_fn)
elif self.inference_disp_fn is None:
# for training on periodic systems
self.displacement = vmap(disp_fn, (0, 0, None, None), 0)
else:
mappable_displacement_fn = get_disp_fn(self.inference_disp_fn)
self.displacement = vmap(mappable_displacement_fn, (0, 0, None, None), 0)
self.displacement = get_displacement(init_box, inference_disp_fn)

def forward(
self,
Expand Down Expand Up @@ -86,11 +93,62 @@ def forward(

# Model Core
atomic_energies = self.atomistic_model(dr_vec, Z, idx)
total_energy = fp64_sum(atomic_energies)
total_energy = torch.sum(atomic_energies, dtype=torch.float64)

# Corrections
# for correction in self.corrections:
# energy_correction = correction(dr_vec, Z, idx)
# total_energy = total_energy + energy_correction

return total_energy


class EnergyDerivativeModel(nn.Module):
def __init__(
self,
energy_model: EnergyModel = EnergyModel(),
calc_stress: bool = False,
):
super().__init__()

self.energy_model = energy_model
self.calc_stress = calc_stress


def forward(
self,
R: torch.Tensor,
Z: torch.Tensor,
neighbor: torch.Tensor,
box: torch.Tensor,
offsets: torch.Tensor,
):
R.requires_grad = True
requires_grad = [R]
if self.calc_stress:
eps = torch.zeros((3, 3), torch.float64)
eps.requires_grad = True
eps_sym = 0.5 * (eps + eps.T)
identity = torch.eye(3, dtype=torch.float64)
perturbation = identity + eps_sym
requires_grad.append(eps)
else:
perturbation = None

energy = self.energy_model(R, Z, neighbor, box, offsets, perturbation)


grads = autograd.grad(energy, requires_grad,
grad_outputs=torch.ones_like(energy),
create_graph=True)

neg_forces = grads[0]
forces = -neg_forces

prediction = {"energy": energy, "forces": forces}

if self.calc_stress:
stress = grads[-1]
prediction["stress"] = stress

return prediction

0 comments on commit 763bf14

Please sign in to comment.