Skip to content

Commit

Permalink
Add support for using a custom list of lambda values. [closes #41]
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Apr 24, 2024
1 parent 34f5bfa commit a76d0a7
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 10 deletions.
43 changes: 42 additions & 1 deletion src/somd2/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
__all__ = ["Config"]


from collections.abc import Iterable as _Iterable
from typing import Iterable as _Iterable
from openmm import Platform as _Platform
from pathlib import Path as _Path

Expand Down Expand Up @@ -72,6 +72,11 @@ class Config:
"log_level": [level.lower() for level in _logger._core.levels],
}

# A dictionary of nargs for the various options.
_nargs = {
"lambda_values": "+",
}

def __init__(
self,
log_level="info",
Expand All @@ -86,6 +91,7 @@ def __init__(
cutoff="7.5A",
h_mass_factor=1.5,
num_lambda=11,
lambda_values=None,
lambda_schedule="standard_morph",
charge_scale_factor=0.2,
swap_end_states=False,
Expand Down Expand Up @@ -153,6 +159,10 @@ def __init__(
num_lambda: int
Number of lambda windows to use.
lambda_values: [float]
A list of lambda values. When specified, this takes precedence over
the 'num_lambda' option.
lambda_schedule: str
Lambda schedule to use for alchemical free energy simulations.
Expand Down Expand Up @@ -281,6 +291,7 @@ def __init__(
self.h_mass_factor = h_mass_factor
self.timestep = timestep
self.num_lambda = num_lambda
self.lambda_values = lambda_values
self.lambda_schedule = lambda_schedule
self.charge_scale_factor = charge_scale_factor
self.swap_end_states = swap_end_states
Expand Down Expand Up @@ -616,6 +627,29 @@ def num_lambda(self, num_lambda):
raise ValueError("'num_lambda' must be an integer")
self._num_lambda = num_lambda

@property
def lambda_values(self):
return self._lambda_values

@lambda_values.setter
def lambda_values(self, lambda_values):
if lambda_values is not None:
if not isinstance(lambda_values, _Iterable):
raise ValueError("'lambda_values' must be an iterable")
try:
lambda_values = [float(x) for x in lambda_values]
except:
raise ValueError("'lambda_values' must be an iterable of floats")

if not all(0 <= x <= 1 for x in lambda_values):
raise ValueError(
"All entries in 'lambda_values' must be between 0 and 1"
)

self._num_lambda = len(lambda_values)

self._lambda_values = lambda_values

@property
def lambda_schedule(self):
return self._lambda_schedule
Expand Down Expand Up @@ -1273,6 +1307,12 @@ def _create_parser(cls):
# Get the type of the parameter. If None, then use str.
typ = str if params[param].default is None else type(params[param].default)

# Get the nargs for the parameter.
if param in cls._nargs:
nargs = cls._nargs[param]
else:
nargs = None

# This parameter has choices.
if param in cls._choices:
parser.add_argument(
Expand All @@ -1297,6 +1337,7 @@ def _create_parser(cls):
parser.add_argument(
f"--{cli_param}",
type=typ,
nargs=nargs,
default=params[param].default,
help=help[param],
required=False,
Expand Down
18 changes: 9 additions & 9 deletions src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,8 @@ def __init__(self, system, config):
The perturbable system to be simulated. This can be either a path
to a stream file, or a Sire system object.
num_lambda: int
The number of lambda windows to be simulated.
platform: str
The platform to be used for simulations.
config: :class: `Config <somd2.config.Config>`
The configuration options for the simulation.
"""

if not isinstance(system, (str, _System)):
Expand Down Expand Up @@ -169,10 +166,13 @@ def __init__(self, system, config):
self._check_end_state_constraints()

# Set the lambda values.
self._lambda_values = [
round(i / (self._config.num_lambda - 1), 5)
for i in range(0, self._config.num_lambda)
]
if self._config.lambda_values:
self._lambda_values = self._config.lambda_values
else:
self._lambda_values = [
round(i / (self._config.num_lambda - 1), 5)
for i in range(0, self._config.num_lambda)
]

# Work out the current hydrogen mass factor.
h_mass_factor, has_hydrogen = self._get_h_mass_factor(self._system)
Expand Down

0 comments on commit a76d0a7

Please sign in to comment.