Skip to content

Commit

Permalink
Merge pull request #44 from OpenBioSim/fix_ghost_sigmas
Browse files Browse the repository at this point in the history
Add support for writing energy components to file and preserving the LJ sigma parameter for ghost atoms
  • Loading branch information
lohedges authored May 30, 2024
2 parents 7f1fa1b + 8cfc3c2 commit a1247ee
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 22 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ simulations. Built on top of [Sire](https://github.com/OpenBioSim/sire) and [Ope
First create a conda environment using the provided environment file:

```
mamba create -f environment.yaml
conda env create -f environment.yaml
```

(We recommend using [Miniforge](https://github.com/conda-forge/miniforge).)

Now install `somd2` into the environment:

```
mamba activate somd2
conda activate somd2
pip install --editable .
```

Expand Down
16 changes: 16 additions & 0 deletions src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
overwrite=False,
somd1_compatibility=False,
pert_file=None,
save_energy_components=False,
):
"""
Constructor.
Expand Down Expand Up @@ -281,6 +282,10 @@ def __init__(
pert_file: str
The path to a SOMD1 perturbation file to apply to the reference system.
When set, this will automatically set 'somd1_compatibility' to True.
save_energy_components: bool
Whether to save the energy contribution for each force when checkpointing.
This is useful when debugging crashes.
"""

# Setup logger before doing anything else
Expand Down Expand Up @@ -327,6 +332,7 @@ def __init__(
self.restart = restart
self.somd1_compatibility = somd1_compatibility
self.pert_file = pert_file
self.save_energy_components = save_energy_components

self.write_config = write_config

Expand Down Expand Up @@ -1201,6 +1207,16 @@ def pert_file(self, pert_file):
if pert_file is not None:
self._somd1_compatibility = True

@property
def save_energy_components(self):
return self._save_energy_components

@save_energy_components.setter
def save_energy_components(self, save_energy_components):
if not isinstance(save_energy_components, bool):
raise ValueError("'save_energy_components' must be of type 'bool'")
self._save_energy_components = save_energy_components

@property
def output_directory(self):
return self._output_directory
Expand Down
54 changes: 53 additions & 1 deletion src/somd2/runner/_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def __init__(
self._config.restart,
)

self._nrg_sample = 0
self._nrg_file = "energy_components.txt"

@staticmethod
def create_filenames(lambda_array, lambda_value, output_directory, restart=False):
# Create incremental file name for current restart.
Expand All @@ -153,6 +156,7 @@ def increment_filename(base_filename, suffix):
filenames["energy_traj"] = f"energy_traj_{lam}.parquet"
filenames["trajectory"] = f"traj_{lam}.dcd"
filenames["trajectory_chunk"] = f"traj_{lam}_"
filenames["energy_components"] = f"energy_components_{lam}.txt"
if restart:
filenames["config"] = increment_filename("config", "yaml")
else:
Expand Down Expand Up @@ -371,7 +375,7 @@ def generate_lam_vals(lambda_base, increment):
)

if self._config.checkpoint_frequency.value() > 0.0:
# Calculate the number of blocks and the remaineder time.
# Calculate the number of blocks and the remainder time.
frac = (
self._config.runtime.value() / self._config.checkpoint_frequency.value()
)
Expand Down Expand Up @@ -409,6 +413,10 @@ def generate_lam_vals(lambda_base, increment):

# Checkpoint.
try:
# Save the energy contribution for each force.
if self._config.save_energy_components:
self._save_energy_components()

# Set to the current block number if this is a restart.
if x == 0:
x = self._current_block
Expand Down Expand Up @@ -584,3 +592,47 @@ def get_timing(self):

def _cleanup(self):
del self._dyn

def _save_energy_components(self):

from copy import deepcopy
import openmm

# Get the current context and system.
context = self._dyn._d._omm_mols
system = deepcopy(context.getSystem())

# Add each force to a unique group.
for i, f in enumerate(system.getForces()):
f.setForceGroup(i)

# Create a new context.
new_context = openmm.Context(system, deepcopy(context.getIntegrator()))
new_context.setPositions(context.getState(getPositions=True).getPositions())

header = f"{'# Sample':>10}"
record = f"{self._nrg_sample:>10}"

# Process the records.
for i, f in enumerate(system.getForces()):
state = new_context.getState(getEnergy=True, groups={i})
header += f"{f.getName():>25}"
record += f"{state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalories_per_mole):>25.2f}"

# Write to file.
if self._nrg_sample == 0:
with open(
self._config.output_directory / self._filenames["energy_components"],
"w",
) as f:
f.write(header + "\n")
f.write(record + "\n")
else:
with open(
self._config.output_directory / self._filenames["energy_components"],
"a",
) as f:
f.write(record + "\n")

# Increment the sample number.
self._nrg_sample += 1
9 changes: 7 additions & 2 deletions src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self, system, config):

_logger.info("Applying SOMD1 perturbation compatibility.")
self._system = _make_compatible(self._system)
self._system = _morph.link_to_reference(self._system)

# Next, swap the water topology so that it is in AMBER format.

Expand Down Expand Up @@ -151,12 +152,14 @@ def __init__(self, system, config):

# Only check for light atoms by the maxium end state mass if running
# in SOMD1 compatibility mode. Ghost atoms are considered light when
# adding bond constraints.
# adding bond constraints. Also fix the LJ sigma for ghost atoms so
# it isn't scaled to zero.
self._config._extra_args["ghosts_are_light"] = True
self._config._extra_args["check_for_h_by_max_mass"] = True
self._config._extra_args["check_for_h_by_mass"] = False
self._config._extra_args["check_for_h_by_element"] = False
self._config._extra_args["check_for_h_by_ambertype"] = False
self._config._extra_args["fix_ghost_sigmas"] = True

# Check for a periodic space.
self._check_space()
Expand Down Expand Up @@ -969,5 +972,7 @@ def _run(sim, is_restart=False):
filename=self._fnames[lambda_value]["energy_traj"],
)
del system
_logger.success(f"{_lam_sym} = {lambda_value} complete, speed = {speed:.2f} ns day-1")
_logger.success(
f"{_lam_sym} = {lambda_value} complete, speed = {speed:.2f} ns day-1"
)
return True
93 changes: 76 additions & 17 deletions src/somd2/runner/_somd1.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,9 @@ def _make_compatible(system):
except KeyError:
raise KeyError("No perturbable molecules in the system")

# Store a dummy element.
dummy = _SireMol.Element("Xx")
# Store a dummy element and ambertype.
element_dummy = _SireMol.Element("Xx")
ambertype_dummy = "du"

for mol in pert_mols:
# Store the molecule info.
Expand All @@ -69,9 +70,43 @@ def _make_compatible(system):
# Get an editable version of the molecule.
edit_mol = mol.edit()

##########################
# First process the bonds.
##########################
##################################
# First fix zero LJ sigmas values.
##################################

# Create a null LJParameter.
null_lj = _SireMM.LJParameter()

for atom in mol.atoms():
# Get the end state LJ sigma values.
lj0 = atom.property("LJ0")
lj1 = atom.property("LJ1")

# Lambda = 0 state has a zero sigma value.
if lj0.sigma() == null_lj.sigma():
# Use the sigma value from the lambda = 1 state.
edit_mol = (
edit_mol.atom(atom.index())
.set_property(
"LJ0", _SireMM.LJParameter(lj1.sigma(), lj0.epsilon())
)
.molecule()
)

# Lambda = 1 state has a zero sigma value.
if lj1.sigma() == null_lj.sigma():
# Use the sigma value from the lambda = 0 state.
edit_mol = (
edit_mol.atom(atom.index())
.set_property(
"LJ1", _SireMM.LJParameter(lj0.sigma(), lj1.epsilon())
)
.molecule()
)

########################
# Now process the bonds.
########################

new_bonds0 = _SireMM.TwoAtomFunctions(mol.info())
new_bonds1 = _SireMM.TwoAtomFunctions(mol.info())
Expand Down Expand Up @@ -534,17 +569,26 @@ def _has_dummy(mol, idxs, is_lambda1=False):
Whether a dummy atom is present.
"""

# Set the element property associated with the end state.
# We need to check by ambertype too since this molecule may have been
# created via sire.morph.create_from_pertfile, in which case the element
# property will have been set to the end state with the largest mass, i.e.
# may no longer by a dummy.
if is_lambda1:
prop = "element1"
element_prop = "element1"
ambertype_prop = "ambertype1"
else:
prop = "element0"
element_prop = "element0"
ambertype_prop = "ambertype0"

dummy = _SireMol.Element(0)
element_dummy = _SireMol.Element(0)
ambertype_dummy = "du"

# Check whether an of the atoms is a dummy.
for idx in idxs:
if mol.atom(idx).property(prop) == dummy:
if (
mol.atom(idx).property(element_prop) == element_dummy
or mol.atom(idx).property(ambertype_prop) == ambertype_dummy
):
return True

return False
Expand Down Expand Up @@ -573,21 +617,36 @@ def _is_dummy(mol, idxs, is_lambda1=False):
Whether each atom is a dummy.
"""

# Set the element property associated with the end state.
# We need to check by ambertype too since this molecule may have been
# created via sire.morph.create_from_pertfile, in which case the element
# property will have been set to the end state with the largest mass, i.e.
# may no longer by a dummy.
if is_lambda1:
prop = "element1"
element_prop = "element1"
ambertype_prop = "ambertype1"
else:
prop = "element0"
element_prop = "element0"
ambertype_prop = "ambertype0"

# Store a dummy element.
dummy = _SireMol.Element(0)
if is_lambda1:
element_prop = "element1"
ambertype_prop = "ambertype1"
else:
element_prop = "element0"
ambertype_prop = "ambertype0"

element_dummy = _SireMol.Element(0)
ambertype_dummy = "du"

# Initialise a list to store the state of each atom.
is_dummy = []

# Check whether each of the atoms is a dummy.
for idx in idxs:
is_dummy.append(mol.atom(idx).property(prop) == dummy)
is_dummy.append(
mol.atom(idx).property(element_prop) == element_dummy
or mol.atom(idx).property(ambertype_prop) == ambertype_dummy
)

return is_dummy

Expand Down Expand Up @@ -622,7 +681,7 @@ def _apply_pert(system, pert_file):
from sire import morph as _morph

# Get the non-water molecules in the system.
non_waters = system["not water"]
non_waters = system["not water"].molecules()

# Try to apply the perturbation to each non-water molecule.
is_pert = False
Expand Down

0 comments on commit a1247ee

Please sign in to comment.