diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 7a2c329..1453cec 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -336,11 +336,23 @@ def run(self): # Create the replica list. replica_list = list(range(self._config.num_lambda)) - # Minimise at each lambda value. This is currently done in serial due to a - # threading issue with the Sire OpenMM minimiser. + # Minimise at each lambda value. if self._config.minimise: - for i in range(self._config.num_lambda): - self._minimise(i) + for i in range(num_batches): + with ThreadPoolExecutor(max_workers=num_workers) as executor: + try: + for success, index, e in executor.map( + self._minimise, + replica_list[i * num_workers : (i + 1) * num_workers], + ): + if not success: + _logger.error( + f"Minimisation failed for {_lam_sym} = {self._lambda_values[index]:.5f}: {e}" + ) + raise e + except KeyboardInterrupt: + _logger.error("Minimisation cancelled. Exiting.") + exit(1) # Current block number. block = 0 @@ -545,12 +557,23 @@ def _minimise(self, index): """ _logger.info(f"Minimising at {_lam_sym} = {self._lambda_values[index]:.5f}") + # Note: For now we minimise with the LocalEnergyMinimizer in OpenMM + # since dynamics.minimise() is not thread safe. + try: + from openmm import LocalEnergyMinimizer + # Get the dynamics object. dynamics = self._dynamics_cache.get(index) - # Minimise the system. - dynamics.minimise(timeout=self._config.timeout) + # Get the context. + context = dynamics.context() + + # Minimise. + LocalEnergyMinimizer.minimize(context) + + # Clear the internal dynamics state. + dynamics._d._clear_state() except Exception as e: return False, index, e