diff --git a/src/somd2/config/_config.py b/src/somd2/config/_config.py index 1dc4533..38120c6 100644 --- a/src/somd2/config/_config.py +++ b/src/somd2/config/_config.py @@ -26,7 +26,7 @@ __all__ = ["Config"] -from collections.abc import Iterable as _Iterable +from typing import Iterable as _Iterable from openmm import Platform as _Platform from pathlib import Path as _Path @@ -72,6 +72,11 @@ class Config: "log_level": [level.lower() for level in _logger._core.levels], } + # A dictionary of nargs for the various options. + _nargs = { + "lambda_values": "+", + } + def __init__( self, log_level="info", @@ -86,6 +91,7 @@ def __init__( cutoff="7.5A", h_mass_factor=1.5, num_lambda=11, + lambda_values=None, lambda_schedule="standard_morph", charge_scale_factor=0.2, swap_end_states=False, @@ -153,6 +159,10 @@ def __init__( num_lambda: int Number of lambda windows to use. + lambda_values: [float] + A list of lambda values. When specified, this takes precedence over + the 'num_lambda' option. + lambda_schedule: str Lambda schedule to use for alchemical free energy simulations. @@ -281,6 +291,7 @@ def __init__( self.h_mass_factor = h_mass_factor self.timestep = timestep self.num_lambda = num_lambda + self.lambda_values = lambda_values self.lambda_schedule = lambda_schedule self.charge_scale_factor = charge_scale_factor self.swap_end_states = swap_end_states @@ -616,6 +627,29 @@ def num_lambda(self, num_lambda): raise ValueError("'num_lambda' must be an integer") self._num_lambda = num_lambda + @property + def lambda_values(self): + return self._lambda_values + + @lambda_values.setter + def lambda_values(self, lambda_values): + if lambda_values is not None: + if not isinstance(lambda_values, _Iterable): + raise ValueError("'lambda_values' must be an iterable") + try: + lambda_values = [float(x) for x in lambda_values] + except: + raise ValueError("'lambda_values' must be an iterable of floats") + + if not all(0 <= x <= 1 for x in lambda_values): + raise ValueError( + "All entries in 'lambda_values' must be between 0 and 1" + ) + + self._num_lambda = len(lambda_values) + + self._lambda_values = lambda_values + @property def lambda_schedule(self): return self._lambda_schedule @@ -1273,6 +1307,12 @@ def _create_parser(cls): # Get the type of the parameter. If None, then use str. typ = str if params[param].default is None else type(params[param].default) + # Get the nargs for the parameter. + if param in cls._nargs: + nargs = cls._nargs[param] + else: + nargs = None + # This parameter has choices. if param in cls._choices: parser.add_argument( @@ -1297,6 +1337,7 @@ def _create_parser(cls): parser.add_argument( f"--{cli_param}", type=typ, + nargs=nargs, default=params[param].default, help=help[param], required=False, diff --git a/src/somd2/runner/_runner.py b/src/somd2/runner/_runner.py index 934bb8f..1a1efde 100644 --- a/src/somd2/runner/_runner.py +++ b/src/somd2/runner/_runner.py @@ -62,11 +62,8 @@ def __init__(self, system, config): The perturbable system to be simulated. This can be either a path to a stream file, or a Sire system object. - num_lambda: int - The number of lambda windows to be simulated. - - platform: str - The platform to be used for simulations. + config: :class: `Config ` + The configuration options for the simulation. """ if not isinstance(system, (str, _System)): @@ -169,10 +166,13 @@ def __init__(self, system, config): self._check_end_state_constraints() # Set the lambda values. - self._lambda_values = [ - round(i / (self._config.num_lambda - 1), 5) - for i in range(0, self._config.num_lambda) - ] + if self._config.lambda_values: + self._lambda_values = self._config.lambda_values + else: + self._lambda_values = [ + round(i / (self._config.num_lambda - 1), 5) + for i in range(0, self._config.num_lambda) + ] # Work out the current hydrogen mass factor. h_mass_factor, has_hydrogen = self._get_h_mass_factor(self._system)