Skip to content

Commit

Permalink
Add support for sampling energies at different lambda values.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Apr 24, 2024
1 parent a76d0a7 commit 843c63b
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 6 deletions.
29 changes: 29 additions & 0 deletions src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class Config:
# A dictionary of nargs for the various options.
_nargs = {
"lambda_values": "+",
"lambda_energy": "+",
}

def __init__(
Expand All @@ -92,6 +93,7 @@ def __init__(
h_mass_factor=1.5,
num_lambda=11,
lambda_values=None,
lambda_energy=None,
lambda_schedule="standard_morph",
charge_scale_factor=0.2,
swap_end_states=False,
Expand Down Expand Up @@ -163,6 +165,11 @@ def __init__(
A list of lambda values. When specified, this takes precedence over
the 'num_lambda' option.
lambda_energy: [float]
A list of lambda values at which to output energy data. If not set,
then this will be set to the same as 'lambda_values', or the values
defined by 'num_lambda' if 'lambda_values' is not set.
lambda_schedule: str
Lambda schedule to use for alchemical free energy simulations.
Expand Down Expand Up @@ -292,6 +299,7 @@ def __init__(
self.timestep = timestep
self.num_lambda = num_lambda
self.lambda_values = lambda_values
self.lambda_energy = lambda_energy
self.lambda_schedule = lambda_schedule
self.charge_scale_factor = charge_scale_factor
self.swap_end_states = swap_end_states
Expand Down Expand Up @@ -650,6 +658,27 @@ def lambda_values(self, lambda_values):

self._lambda_values = lambda_values

@property
def lambda_energy(self):
    """
    Return the lambda values at which energy data is output.

    Returns
    -------

    lambda_energy : [float] or None
        The list stored by the setter, or None if no list was given.
    """
    return self._lambda_energy

@lambda_energy.setter
def lambda_energy(self, lambda_energy):
    """
    Validate and store the lambda values at which to output energy data.

    Parameters
    ----------

    lambda_energy : iterable of float, or None
        Values that can each be converted with ``float()`` and that lie
        in the closed interval [0, 1]. None is stored unchanged.

    Raises
    ------

    ValueError
        If the argument is not an iterable, if an entry cannot be
        converted to a float, or if an entry lies outside [0, 1].
    """
    if lambda_energy is not None:
        if not isinstance(lambda_energy, _Iterable):
            raise ValueError("'lambda_energy' must be an iterable")
        try:
            lambda_energy = [float(x) for x in lambda_energy]
        # Only trap genuine conversion failures. A bare 'except' would
        # also swallow KeyboardInterrupt/SystemExit and misreport them
        # as a bad option value.
        except (TypeError, ValueError, OverflowError) as e:
            raise ValueError(
                "'lambda_energy' must be an iterable of floats"
            ) from e

        # NaN fails this comparison, so it is rejected here as well.
        if not all(0 <= x <= 1 for x in lambda_energy):
            raise ValueError(
                "All entries in 'lambda_energy' must be between 0 and 1"
            )

    self._lambda_energy = lambda_energy

@property
def lambda_schedule(self):
return self._lambda_schedule
Expand Down
18 changes: 13 additions & 5 deletions src/somd2/runner/_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
system,
lambda_val,
lambda_array,
lambda_energy,
config,
increment=0.001,
device=None,
Expand All @@ -61,8 +62,11 @@ def __init__(
Lambda value for the simulation
lambda_array : list
List of lambda values to be used for perturbation, if none won't return
reduced perturbed energies
List of lambda values to be used for simulation.
lambda_energy: list
List of lambda values to be used for sampling energies. If None, then we
won't return reduced perturbed energies.
increment : float
Increment of lambda value - used for calculating the gradient
Expand Down Expand Up @@ -107,18 +111,22 @@ def __init__(

self._lambda_val = lambda_val
self._lambda_array = lambda_array
self._lambda_energy = lambda_energy
self._increment = increment
self._device = device
self._has_space = has_space
self._filenames = self.create_filenames(
self._lambda_array,
self._lambda_val,
self._lambda_energy,
self._config.output_directory,
self._config.restart,
)

@staticmethod
def create_filenames(lambda_array, lambda_value, output_directory, restart=False):
def create_filenames(
lambda_array, lambda_value, lambda_energy, output_directory, restart=False
):
# Create incremental file name for current restart.
def increment_filename(base_filename, suffix):
file_number = 0
Expand Down Expand Up @@ -348,10 +356,10 @@ def generate_lam_vals(lambda_base, increment):
# Work out the lambda values for finite-difference gradient analysis.
self._lambda_grad = generate_lam_vals(self._lambda_val, self._increment)

if self._lambda_array is None:
if self._lambda_energy is None:
lam_arr = self._lambda_grad
else:
lam_arr = self._lambda_array + self._lambda_grad
lam_arr = self._lambda_energy + self._lambda_grad

_logger.info(f"Running dynamics at {_lam_sym} = {self._lambda_val}")

Expand Down
16 changes: 15 additions & 1 deletion src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ def __init__(self, system, config):
for i in range(0, self._config.num_lambda)
]

# Set the lambda energy list.
if self._config.lambda_energy is not None:
self._lambda_energy = self._config.lambda_energy
else:
self._lambda_energy = self._config.lambda_values

# Work out the current hydrogen mass factor.
h_mass_factor, has_hydrogen = self._get_h_mass_factor(self._system)

Expand Down Expand Up @@ -312,6 +318,7 @@ def _check_directory(self):
files = Dynamics.create_filenames(
self._lambda_values,
lambda_value,
self._lambda_energy,
self._config.output_directory,
self._config.restart,
)
Expand Down Expand Up @@ -679,6 +686,7 @@ def _initialise_simulation(self, system, lambda_value, device=None):
system,
lambda_val=lambda_value,
lambda_array=self._lambda_values,
lambda_energy=self._lambda_energy,
config=self._config,
device=device,
has_space=self._has_space,
Expand Down Expand Up @@ -939,6 +947,12 @@ def _run(sim, is_restart=False):

from somd2 import __version__, _sire_version, _sire_revisionid

# Add the current lambda value to the list of lambda values and sort.
lambda_array = self._lambda_energy.copy()
if lambda_value not in lambda_array:
lambda_array.append(lambda_value)
lambda_array = sorted(lambda_array)

# Write final dataframe for the system to the energy trajectory file.
# Note that sire s3 checkpoint files contain energy trajectory data, so this works even for restarts.
_ = _dataframe_to_parquet(
Expand All @@ -948,7 +962,7 @@ def _run(sim, is_restart=False):
"somd2 version": __version__,
"sire version": f"{_sire_version}+{_sire_revisionid}",
"lambda": str(lambda_value),
"lambda_array": self._lambda_values,
"lambda_array": lambda_array,
"lambda_grad": lambda_grad,
"speed": speed,
"temperature": str(self._config.temperature.value()),
Expand Down
90 changes: 90 additions & 0 deletions tests/runner/test_lambda_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from pathlib import Path

import tempfile
import pytest

import sire as sr

from somd2.runner import Runner
from somd2.config import Config
from somd2.io import *


def test_lambda_values(ethane_methanol):
    """
    Check that a simulation runs when given an explicit list of lambda
    values, and that the energy trajectory output reflects that list.
    """

    with tempfile.TemporaryDirectory() as tmpdir:
        # Options for the run, using three explicit lambda values.
        options = {
            "runtime": "12fs",
            "restart": False,
            "output_directory": tmpdir,
            "energy_frequency": "4fs",
            "checkpoint_frequency": "4fs",
            "frame_frequency": "4fs",
            "platform": "CPU",
            "max_threads": 1,
            "lambda_values": [0.0, 0.5, 1.0],
        }

        # Work on a copy of the fixture system.
        system = ethane_methanol.clone()

        # Build the runner from the options and run the simulation.
        sim_runner = Runner(system, Config(**options))
        sim_runner.run()

        # Read back the energy trajectory written for the first window.
        traj_file = Path(tmpdir) / "energy_traj_0.parquet"
        energy_traj, meta = parquet_to_dataframe(traj_file)

        # The metadata lambda_array should equal the configured
        # lambda_values list.
        expected_lambdas = [0.0, 0.5, 1.0]
        assert meta["lambda_array"] == expected_lambdas

        # Expected columns: one for the current lambda value, one for its
        # gradient, and two for the remaining lambda_values entries.
        num_columns = energy_traj.shape[1]
        assert num_columns == 4


def test_lambda_energy(ethane_methanol):
    """
    Check that energies can be sampled at lambda values that differ from
    the ones used for the simulation.
    """

    with tempfile.TemporaryDirectory() as tmpdir:
        # Simulate at the end states, but sample energies at the midpoint.
        options = {
            "runtime": "12fs",
            "restart": False,
            "output_directory": tmpdir,
            "energy_frequency": "4fs",
            "checkpoint_frequency": "4fs",
            "frame_frequency": "4fs",
            "platform": "CPU",
            "max_threads": 1,
            "lambda_values": [0.0, 1.0],
            "lambda_energy": [0.5],
        }

        # Work on a copy of the fixture system.
        system = ethane_methanol.clone()

        # Build the runner from the options and run the simulation.
        sim_runner = Runner(system, Config(**options))
        sim_runner.run()

        # Read back the energy trajectory written for the first window.
        traj_file = Path(tmpdir) / "energy_traj_0.parquet"
        energy_traj, meta = parquet_to_dataframe(traj_file)

        # The metadata lambda_array should be the sampled lambda value
        # together with the configured lambda_energy values.
        expected_lambdas = [0.0, 0.5]
        assert meta["lambda_array"] == expected_lambdas

        # Expected columns: one for the current lambda value, one for its
        # gradient, and one per entry in lambda_energy.
        num_columns = energy_traj.shape[1]
        assert num_columns == 3

0 comments on commit 843c63b

Please sign in to comment.