diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 986f439..9aba785 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -72,6 +72,8 @@ def __init__(self, system, lambdas, num_gpus, dynamics_kwargs): self._states = _np.array(range(len(lambdas))) self._openmm_states = [None] * len(lambdas) self._openmm_volumes = [None] * len(lambdas) + self._num_proposed = _np.matrix(_np.zeros((len(lambdas), len(lambdas)))) + self._num_accepted = _np.matrix(_np.zeros((len(lambdas), len(lambdas)))) # Copy the dynamics keyword arguments. dynamics_kwargs = dynamics_kwargs.copy() @@ -185,6 +187,18 @@ def mix_states(self): _logger.debug(f"Replica {i} seeded from state {state}") self._dynamics[i]._d._omm_mols.setState(self._openmm_states[state]) + def get_proposed(self): + """ + Return the number of proposed swaps between replicas. + """ + return self._num_proposed + + def get_accepted(self): + """ + Return the number of accepted swaps between replicas. + """ + return self._num_accepted + class RepexRunner(_RunnerBase): """ @@ -263,6 +277,9 @@ def __init__(self, system, config): else: self._start_block = 0 + # Store the name of the replica exchange swap acceptance matrix. + self._repex_matrix = self._config.output_directory / "repex_matrix.txt" + from threading import Lock # Create a lock to guard the dynamics cache. @@ -412,14 +429,25 @@ def run(self): self._mix_replicas( self._config.num_lambda, energy_matrix, + self._dynamics_cache.get_proposed(), + self._dynamics_cache.get_accepted(), ) ) self._dynamics_cache.mix_states() - # Update the block number. + # This is a checkpoint cycle. if is_checkpoint: + # Update the block number. block += 1 + # Save the replica exchange swap acceptance matrix. + _np.savetxt( + self._repex_matrix, + self._dynamics_cache.get_accepted() + / self._dynamics_cache.get_proposed(), + fmt="%.5f", + ) + # Record the end time. end = time() @@ -593,7 +621,7 @@ def _assemble_results(self, results): @staticmethod @_njit - def _mix_replicas(num_replicas, energy_matrix): + def _mix_replicas(num_replicas, energy_matrix, proposed, accepted): """ Mix the replicas. @@ -628,6 +656,9 @@ def _mix_replicas(num_replicas, energy_matrix): state_i = states[replica_i] state_j = states[replica_j] + # Record that we have proposed a swap. + proposed[state_i, state_j] += 1 + # Get the energies. energy_ii = energy_matrix[replica_i, state_i] energy_jj = energy_matrix[replica_j, state_j] @@ -639,7 +670,10 @@ def _mix_replicas(num_replicas, energy_matrix): # Accept the swap and update the states. if log_p_swap >= 0 or _np.random.rand() < _np.exp(log_p_swap): + # Swap the states. states[replica_i] = state_j states[replica_j] = state_i + # Record the swap. + accepted[state_i, state_j] += 1 return states