Skip to content

Commit

Permalink
Add functionality to save the repex swap acceptance matrix.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Nov 19, 2024
1 parent 5793db2 commit 2f95b1f
Showing 1 changed file with 36 additions and 2 deletions.
38 changes: 36 additions & 2 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def __init__(self, system, lambdas, num_gpus, dynamics_kwargs):
self._states = _np.array(range(len(lambdas)))
self._openmm_states = [None] * len(lambdas)
self._openmm_volumes = [None] * len(lambdas)
self._num_proposed = _np.matrix(_np.zeros((len(lambdas), len(lambdas))))
self._num_accepted = _np.matrix(_np.zeros((len(lambdas), len(lambdas))))

# Copy the dynamics keyword arguments.
dynamics_kwargs = dynamics_kwargs.copy()
Expand Down Expand Up @@ -185,6 +187,18 @@ def mix_states(self):
_logger.debug(f"Replica {i} seeded from state {state}")
self._dynamics[i]._d._omm_mols.setState(self._openmm_states[state])

def get_proposed(self):
"""
Return the number of proposed swaps between replicas.
"""
return self._num_proposed

def get_accepted(self):
"""
Return the number of accepted swaps between replicas.
"""
return self._num_accepted


class RepexRunner(_RunnerBase):
"""
Expand Down Expand Up @@ -263,6 +277,9 @@ def __init__(self, system, config):
else:
self._start_block = 0

# Store the name of the replica exchange swap acceptance matrix.
self._repex_matrix = self._config.output_directory / "repex_matrix.txt"

from threading import Lock

# Create a lock to guard the dynamics cache.
Expand Down Expand Up @@ -412,14 +429,25 @@ def run(self):
self._mix_replicas(
self._config.num_lambda,
energy_matrix,
self._dynamics_cache.get_proposed(),
self._dynamics_cache.get_accepted(),
)
)
self._dynamics_cache.mix_states()

# Update the block number.
# This is a checkpoint cycle.
if is_checkpoint:
# Update the block number.
block += 1

# Save the replica exchange swap acceptance matrix.
_np.savetxt(
self._repex_matrix,
self._dynamics_cache.get_accepted()
/ self._dynamics_cache.get_proposed(),
fmt="%.5f",
)

# Record the end time.
end = time()

Expand Down Expand Up @@ -593,7 +621,7 @@ def _assemble_results(self, results):

@staticmethod
@_njit
def _mix_replicas(num_replicas, energy_matrix):
def _mix_replicas(num_replicas, energy_matrix, proposed, accepted):
"""
Mix the replicas.
Expand Down Expand Up @@ -628,6 +656,9 @@ def _mix_replicas(num_replicas, energy_matrix):
state_i = states[replica_i]
state_j = states[replica_j]

# Record that we have proposed a swap.
proposed[state_i, state_j] += 1

# Get the energies.
energy_ii = energy_matrix[replica_i, state_i]
energy_jj = energy_matrix[replica_j, state_j]
Expand All @@ -639,7 +670,10 @@ def _mix_replicas(num_replicas, energy_matrix):

# Accept the swap and update the states.
if log_p_swap >= 0 or _np.random.rand() < _np.exp(log_p_swap):
# Swap the states.
states[replica_i] = state_j
states[replica_j] = state_i
# Record the swap.
accepted[state_i, state_j] += 1

return states

0 comments on commit 2f95b1f

Please sign in to comment.