Skip to content

Commit

Permalink
Rework mixing scheme.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Oct 31, 2024
1 parent 940b9a3 commit c20118a
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,21 +162,18 @@ def set_states(self, states):
states: np.ndarray
The new states.
"""
self._old_states = self._states
self._states = states

def mix_states(self):
"""
Mix the states of the dynamics objects.
"""
# Mix the states.
for i, (old_state, new_state) in enumerate(zip(self._old_states, self._states)):
for i, state in enumerate(self._states):
# The state has changed.
if old_state != new_state:
_logger.debug(
f"Replica {i} changed state from {old_state} to {new_state}"
)
self._dynamics[i]._d._omm_mols.setState(self._openmm_states[new_state])
if i != state:
_logger.debug(f"Replica {i} changed state to {state}")
self._dynamics[i]._d._omm_mols.setState(self._openmm_states[state])


class RepexRunner(_RunnerBase):
Expand Down Expand Up @@ -394,7 +391,6 @@ def run(self):
self._mix_replicas(
self._config.num_lambda,
energy_matrix,
self._dynamics_cache._states,
)
)
self._dynamics_cache.mix_states()
Expand Down Expand Up @@ -567,7 +563,7 @@ def _assemble_results(self, results):

@staticmethod
@_njit
def _mix_replicas(num_replicas, energy_matrix, states):
def _mix_replicas(num_replicas, energy_matrix):
"""
Mix the replicas.
Expand All @@ -577,25 +573,22 @@ def _mix_replicas(num_replicas, energy_matrix, states):
num_replicas: int
The number of replicas.
num_attempts: int
The number of attempts to make.
energy_matrix: np.ndarray
The energy matrix for the replicas.
states: np.ndarray
The current state for each replica.
Returns
-------
states: np.ndarray
The new states.
"""

# Copy the states.
states = states.copy()
# Adapted from OpenMMTools: https://github.com/choderalab/openmmtools

# Set the states to the initial order.
states = _np.arange(num_replicas)

# Attempt swaps.
for swap in range(num_replicas**3):
# Choose two replicas to swap.
replica_i = _np.random.randint(num_replicas)
Expand All @@ -614,7 +607,7 @@ def _mix_replicas(num_replicas, energy_matrix, states):
# Compute the log probability of the swap.
log_p_swap = -(energy_ij + energy_ji) + energy_ii + energy_jj

# Accept or reject the swap.
# Accept the swap and update the states.
if log_p_swap >= 0 or _np.random.rand() < _np.exp(log_p_swap):
states[replica_i] = state_j
states[replica_j] = state_i
Expand Down

0 comments on commit c20118a

Please sign in to comment.