diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index c568111..62fbded 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -277,8 +277,8 @@ def run(self): from math import ceil from time import time - from concurrent.futures import ThreadPoolExecutor as _ThreadPoolExecutor - from itertools import repeat as _repeat + from concurrent.futures import ThreadPoolExecutor + from itertools import repeat # Record the start time. start = time() @@ -320,8 +320,8 @@ def run(self): num_blocks = 1 rem = 0 - # Work out the required number of executors. - executors = ceil( + # Work out the required number of batches. + num_batches = ceil( self._config.num_lambda / (self._num_gpus * self._config.oversubscription_factor) ) @@ -331,19 +331,29 @@ def run(self): # Minimise at each lambda value. if self._config.minimise: - with _ThreadPoolExecutor(max_workers=self._num_gpus) as executor: - try: - for result, index, exception in executor.map( - self._minimise, replica_list - ): - if not result: - _logger.error( - f"Minimisation failed for {_lam_sym} = {self._lambda_values[index]:.5f}: {exception}" - ) - raise exception - except KeyboardInterrupt: - _logger.error("Minimisation cancelled. Exiting.") - exit(1) + # Run minimisation for each replica, making sure only each GPU is only + # oversubscribed by a factor of self._config.oversubscription_factor. + for i in range(num_batches): + with ThreadPoolExecutor() as executor: + try: + for result, index, exception in executor.map( + self._minimise, + replica_list[ + i + * self._num_gpus + * self._config.oversubscription_factor : (i + 1) + * self._num_gpus + * self._config.oversubscription_factor + ], + ): + if not result: + _logger.error( + f"Minimisation failed for {_lam_sym} = {self._lambda_values[index]:.5f}: {exception}" + ) + raise exception + except KeyboardInterrupt: + _logger.error("Minimisation cancelled. Exiting.") + exit(1) # Current block number. block = 0 @@ -369,7 +379,7 @@ def run(self): # Run a dynamics block for each replica, making sure only each GPU is only # oversubscribed by a factor of self._config.oversubscription_factor. - for j in range(executors): + for j in range(num_batches): replicas = replica_list[ j * self._num_gpus @@ -377,18 +387,16 @@ def run(self): * self._num_gpus * self._config.oversubscription_factor ] - with _ThreadPoolExecutor( - max_workers=self._num_gpus * self._config.oversubscription_factor - ) as executor: + with ThreadPoolExecutor() as executor: try: for result, index, energies in executor.map( self._run_block, replicas, - _repeat(self._lambda_values), - _repeat(is_checkpoint), - _repeat(i == cycles - 1), - _repeat(block), - _repeat(num_blocks + int(rem > 0)), + repeat(self._lambda_values), + repeat(is_checkpoint), + repeat(i == cycles - 1), + repeat(block), + repeat(num_blocks + int(rem > 0)), ): if not result: _logger.error(