Skip to content

Commit

Permalink
Minimise in serial and add mutex around checkpointing.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Nov 4, 2024
1 parent 51ecbd5 commit 7bb9df8
Showing 1 changed file with 16 additions and 28 deletions.
44 changes: 16 additions & 28 deletions src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def __init__(self, system, lambdas, num_gpus, dynamics_kwargs):
# Append the dynamics object.
self._dynamics.append(dynamics)

_logger.info(f"Created dynamics object for lambda {lam:.5f} on device {device}")
_logger.info(
f"Created dynamics object for lambda {lam:.5f} on device {device}"
)

def get(self, index):
"""
Expand Down Expand Up @@ -261,6 +263,11 @@ def __init__(self, system, config):
else:
self._start_block = 0

from threading import Lock

# Create a lock to guard the dynamics cache.
self._lock = Lock()

def __str__(self):
"""Return a string representation of the object."""
return f"RepexRunner(system={self._system}, config={self._config})"
Expand Down Expand Up @@ -329,31 +336,11 @@ def run(self):
# Create the replica list.
replica_list = list(range(self._config.num_lambda))

# Minimise at each lambda value.
# Minimise at each lambda value. This is currently done in serial due to a
# limitation in OpenMM.
if self._config.minimise:
# Run minimisation for each replica, making sure only each GPU is only
# oversubscribed by a factor of self._config.oversubscription_factor.
for i in range(num_batches):
with ThreadPoolExecutor() as executor:
try:
for result, index, exception in executor.map(
self._minimise,
replica_list[
i
* self._num_gpus
* self._config.oversubscription_factor : (i + 1)
* self._num_gpus
* self._config.oversubscription_factor
],
):
if not result:
_logger.error(
f"Minimisation failed for {_lam_sym} = {self._lambda_values[index]:.5f}: {exception}"
)
raise exception
except KeyboardInterrupt:
_logger.error("Minimisation cancelled. Exiting.")
exit(1)
for i in range(self._config.num_lambda):
self._minimise(i)

# Current block number.
block = 0
Expand Down Expand Up @@ -515,9 +502,10 @@ def _run_block(
speed = dynamics.time_speed()

# Checkpoint.
self._checkpoint(
system, index, block, speed, is_final_block=is_final_block
)
with self._lock:
self._checkpoint(
system, index, block, speed, is_final_block=is_final_block
)

_logger.info(
f"Finished block {block+1} of {self._start_block + num_blocks} "
Expand Down

0 comments on commit 7bb9df8

Please sign in to comment.