Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiStateReporter variable pos/vel save frequency #712

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
90 changes: 63 additions & 27 deletions openmmtools/multistate/multistatereporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class MultiStateReporter(object):
analysis_particle_indices : tuple of ints, Optional. Default: () (empty tuple)
If specified, it will serialize positions and velocities for the specified particles, at every iteration, in the
reporter storage (.nc) file. If empty, no positions or velocities will be stored in this file for any atoms.
position_interval : int, default 1
the frequency at which to write positions relative to analysis
information, 0 would prevent information being written
velocity_interval : int, default 1
the frequency at which to write positions relative to analysis
information, 0 would prevent information being written

Attributes
----------
Expand All @@ -113,7 +119,10 @@ class MultiStateReporter(object):
"""
def __init__(self, storage, open_mode=None,
checkpoint_interval=50, checkpoint_storage=None,
analysis_particle_indices=()):
analysis_particle_indices=(),
position_interval=1,
velocity_interval=1,
):

# Warn that API is experimental
logger.warn('Warning: The openmmtools.multistate API is experimental and may change in future releases')
Expand All @@ -136,6 +145,9 @@ def __init__(self, storage, open_mode=None,
self._checkpoint_interval = checkpoint_interval
# Cast to tuple no mater what 1-D-like input was given
self._analysis_particle_indices = tuple(analysis_particle_indices)
self._position_interval = position_interval
self._velocity_interval = velocity_interval

if open_mode is not None:
self.open(open_mode)
# TODO: Maybe we want to expose this flag to control ovrwriting/appending
Expand Down Expand Up @@ -202,6 +214,16 @@ def checkpoint_interval(self):
"""Returns the checkpoint interval"""
return self._checkpoint_interval

@property
def position_interval(self):
"""Interval relative to energies that positions are written at"""
return self._position_interval

@property
def velocity_interval(self):
"""Interval relative to energies that velocities are written at"""
return self._velocity_interval

def storage_exists(self, skip_size=False):
"""
Check if the storage files exist on disk.
Expand Down Expand Up @@ -415,6 +437,8 @@ def _initialize_storage_file(self, ncfile, nc_name, convention):
ncfile.ConventionVersion = '0.2'
ncfile.DataUsedFor = nc_name
ncfile.CheckpointInterval = self._checkpoint_interval
ncfile.PositionInterval = self._position_interval
ncfile.VelocityInterval = self._velocity_interval

# Create and initialize the global variables
nc_last_good_iter = ncfile.createVariable('last_iteration', int, 'scalar')
Expand Down Expand Up @@ -1647,35 +1671,47 @@ def _write_sampler_states_to_given_file(self, sampler_states: list, iteration: i
write_iteration = self._calculate_checkpoint_iteration(iteration)
else:
write_iteration = iteration

# write out pos/vel - if checkpointing,
# or if interval matches desired frequency
write_pos = (storage_file == 'checkpoint' or
(self._position_interval != 0
and not (write_iteration % self._position_interval)))
write_vel = (storage_file == 'checkpoint' or
(self._velocity_interval != 0
and not (write_iteration % self._velocity_interval)))

# Write the sampler state if we are on the checkpoint interval OR if told to ignore the interval
if write_iteration is not None:
# Store sampler states.
# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
positions = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
# Store positions in memory first
x = sampler_state.positions / unit.nanometers
positions[replica_index, :, :] = x[:, :]
# Store positions
storage.variables['positions'][write_iteration, :, :, :] = positions

# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
velocities = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
if sampler_state._unitless_velocities is not None:
# Store velocities in memory first
x = sampler_state.velocities / (unit.nanometer/unit.picoseconds) # _unitless_velocities
velocities[replica_index, :, :] = x[:, :]
# Store velocites
# TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored
# sampler_state different from origin.
if 'velocities' not in storage.variables:
# create variable with expected dimensions and shape
storage.createVariable('velocities', storage.variables['positions'].dtype,
dimensions=storage.variables['positions'].dimensions)
storage.variables['velocities'][write_iteration, :, :, :] = velocities

if is_periodic:
if write_pos:
# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
positions = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
# Store positions in memory first
x = sampler_state.positions / unit.nanometers
positions[replica_index, :, :] = x[:, :]
# Store positions
storage.variables['positions'][write_iteration, :, :, :] = positions

if write_vel:
# Create a numpy array to avoid making multiple (possibly inefficient) calls to netCDF assignments
velocities = np.zeros([n_replicas, n_particles, 3])
for replica_index, sampler_state in enumerate(sampler_states):
if sampler_state._unitless_velocities is not None:
# Store velocities in memory first
x = sampler_state.velocities / (unit.nanometer/unit.picoseconds) # _unitless_velocities
velocities[replica_index, :, :] = x[:, :]
# Store velocites
# TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored
# sampler_state different from origin.
if 'velocities' not in storage.variables:
# create variable with expected dimensions and shape
storage.createVariable('velocities', storage.variables['positions'].dtype,
dimensions=storage.variables['positions'].dimensions)
storage.variables['velocities'][write_iteration, :, :, :] = velocities

if is_periodic and write_pos:
# Store box vectors and volume.
# Allocate whole write to memory first
box_vectors = np.zeros([n_replicas, 3, 3])
Expand Down
Loading