From c76084e48ceae9bf965cd94009c7a1fb3f70f833 Mon Sep 17 00:00:00 2001 From: David Dotson Date: Wed, 19 Apr 2023 22:22:32 -0700 Subject: [PATCH 01/11] Cleanup contexts upon calculation completion, failure This adds calls to `ContextCache.empty` at the end of repex execution, or in the case of execution failure, to avoid leaving behind stale contexts. This can be important for long-running processes executing multiple `ProtocolDAG`s in sequence. --- .../protocols/openmm_rfe/equil_rfe_methods.py | 199 +++++++++--------- 1 file changed, 103 insertions(+), 96 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 8cdd13b6c..e71cf7a14 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -323,20 +323,20 @@ def run(self, *, dry=False, verbose=True, # a. check timestep correctness + that # equilibration & production are divisible by n_steps - prototol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings'] + protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings'] stateA = self._inputs['stateA'] stateB = self._inputs['stateB'] mapping = self._inputs['ligandmapping'] - forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = prototol_settings.forcefield_settings - thermo_settings: settings.ThermoSettings = prototol_settings.thermo_settings - alchem_settings: AlchemicalSettings = prototol_settings.alchemical_settings - system_settings: SystemSettings = prototol_settings.system_settings - solvation_settings: SolvationSettings = prototol_settings.solvation_settings - sampler_settings: AlchemicalSamplerSettings = prototol_settings.alchemical_sampler_settings - sim_settings: SimulationSettings = prototol_settings.simulation_settings - timestep = prototol_settings.integrator_settings.timestep - mc_steps = prototol_settings.integrator_settings.n_steps.m + forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings + thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings + alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings + system_settings: SystemSettings = protocol_settings.system_settings + solvation_settings: SolvationSettings = protocol_settings.solvation_settings + sampler_settings: AlchemicalSamplerSettings = protocol_settings.alchemical_sampler_settings + sim_settings: SimulationSettings = protocol_settings.simulation_settings + timestep = protocol_settings.integrator_settings.timestep + mc_steps = protocol_settings.integrator_settings.n_steps.m # is the timestep good for the mass? if forcefield_settings.hydrogen_mass < 3.0: @@ -536,9 +536,9 @@ def run(self, *, dry=False, verbose=True, if 'solvent' in stateA.components: hybrid_factory.hybrid_system.addForce( openmm.MonteCarloBarostat( - prototol_settings.thermo_settings.pressure.to(unit.bar).m, - prototol_settings.thermo_settings.temperature.m, - prototol_settings.integrator_settings.barostat_frequency.m, + protocol_settings.thermo_settings.pressure.to(unit.bar).m, + protocol_settings.thermo_settings.temperature.m, + protocol_settings.integrator_settings.barostat_frequency.m, ) ) @@ -574,9 +574,13 @@ def run(self, *, dry=False, verbose=True, # 10. Get platform and context caches platform = _rfe_utils.compute.get_openmm_platform( - prototol_settings.engine_settings.compute_platform + protocol_settings.engine_settings.compute_platform ) + # 11. Set the integrator + # a. get integrator settings + integrator_settings = protocol_settings.integrator_settings + # a. Create context caches (energy + sampler) # Note: these needs to exist on the compute node energy_context_cache = openmmtools.cache.ContextCache( @@ -587,107 +591,110 @@ def run(self, *, dry=False, verbose=True, capacity=None, time_to_live=None, platform=platform, ) - # 11. Set the integrator - # a. get integrator settings - integrator_settings = prototol_settings.integrator_settings - - # b. create langevin integrator - integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove( - timestep=to_openmm(integrator_settings.timestep), - collision_rate=to_openmm(integrator_settings.collision_rate), - n_steps=integrator_settings.n_steps.m, - reassign_velocities=integrator_settings.reassign_velocities, - n_restart_attempts=integrator_settings.n_restart_attempts, - constraint_tolerance=integrator_settings.constraint_tolerance, - splitting=integrator_settings.splitting - ) - - # 12. Create sampler - if sampler_settings.sampler_method.lower() == "repex": - sampler = _rfe_utils.multistate.HybridRepexSampler( - mcmc_moves=integrator, - hybrid_factory=hybrid_factory, - online_analysis_interval=sampler_settings.online_analysis_interval, - online_analysis_target_error=sampler_settings.online_analysis_target_error.m, - online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations - ) - elif sampler_settings.sampler_method.lower() == "sams": - sampler = _rfe_utils.multistate.HybridSAMSSampler( - mcmc_moves=integrator, - hybrid_factory=hybrid_factory, - online_analysis_interval=sampler_settings.online_analysis_interval, - online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations, - flatness_criteria=sampler_settings.flatness_criteria, - gamma0=sampler_settings.gamma0, + try: + # b. create langevin integrator + integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove( + timestep=to_openmm(integrator_settings.timestep), + collision_rate=to_openmm(integrator_settings.collision_rate), + n_steps=integrator_settings.n_steps.m, + reassign_velocities=integrator_settings.reassign_velocities, + n_restart_attempts=integrator_settings.n_restart_attempts, + constraint_tolerance=integrator_settings.constraint_tolerance, + splitting=integrator_settings.splitting ) - elif sampler_settings.sampler_method.lower() == 'independent': - sampler = _rfe_utils.multistate.HybridMultiStateSampler( - mcmc_moves=integrator, - hybrid_factory=hybrid_factory, - online_analysis_interval=sampler_settings.online_analysis_interval, - online_analysis_target_error=sampler_settings.online_analysis_target_error.m, - online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations + + # 12. Create sampler + if sampler_settings.sampler_method.lower() == "repex": + sampler = _rfe_utils.multistate.HybridRepexSampler( + mcmc_moves=integrator, + hybrid_factory=hybrid_factory, + online_analysis_interval=sampler_settings.online_analysis_interval, + online_analysis_target_error=sampler_settings.online_analysis_target_error.m, + online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations + ) + elif sampler_settings.sampler_method.lower() == "sams": + sampler = _rfe_utils.multistate.HybridSAMSSampler( + mcmc_moves=integrator, + hybrid_factory=hybrid_factory, + online_analysis_interval=sampler_settings.online_analysis_interval, + online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations, + flatness_criteria=sampler_settings.flatness_criteria, + gamma0=sampler_settings.gamma0, + ) + elif sampler_settings.sampler_method.lower() == 'independent': + sampler = _rfe_utils.multistate.HybridMultiStateSampler( + mcmc_moves=integrator, + hybrid_factory=hybrid_factory, + online_analysis_interval=sampler_settings.online_analysis_interval, + online_analysis_target_error=sampler_settings.online_analysis_target_error.m, + online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations + ) + + else: + raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}") + + sampler.setup( + n_replicas=sampler_settings.n_replicas, + reporter=reporter, + platform=platform, + lambda_protocol=lambdas, + temperature=to_openmm(protocol_settings.thermo_settings.temperature), + endstates=alchem_settings.unsampled_endstates, ) - else: - raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}") - - sampler.setup( - n_replicas=sampler_settings.n_replicas, - reporter=reporter, - platform=platform, - lambda_protocol=lambdas, - temperature=to_openmm(prototol_settings.thermo_settings.temperature), - endstates=alchem_settings.unsampled_endstates, - ) + sampler.energy_context_cache = energy_context_cache + sampler.sampler_context_cache = sampler_context_cache - sampler.energy_context_cache = energy_context_cache - sampler.sampler_context_cache = sampler_context_cache + if not dry: # pragma: no-cover + # minimize + if verbose: + logger.info("minimizing systems") - if not dry: # pragma: no-cover - # minimize - if verbose: - logger.info("minimizing systems") + sampler.minimize(max_iterations=sim_settings.minimization_steps) - sampler.minimize(max_iterations=sim_settings.minimization_steps) + # equilibrate + if verbose: + logger.info("equilibrating systems") - # equilibrate - if verbose: - logger.info("equilibrating systems") + sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore - sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore + # production + if verbose: + logger.info("running production phase") - # production - if verbose: - logger.info("running production phase") + sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore - sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore + # calculate estimate of results from this individual unit + ana = multistate.MultiStateSamplerAnalyzer(reporter) + est, _ = ana.get_free_energy() + est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole) + est = ensure_quantity(est, 'openff') - # calculate estimate of results from this individual unit - ana = multistate.MultiStateSamplerAnalyzer(reporter) - est, _ = ana.get_free_energy() - est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole) - est = ensure_quantity(est, 'openff') + # close reporter when you're done + reporter.close() - # close reporter when you're done - reporter.close() + nc = shared_basepath / sim_settings.output_filename + chk = shared_basepath / sim_settings.checkpoint_storage + else: + # close reporter when you're done, prevent file handle clashes + reporter.close() - nc = shared_basepath / sim_settings.output_filename - chk = shared_basepath / sim_settings.checkpoint_storage + # clean up the reporter file + fns = [shared_basepath / sim_settings.output_filename, + shared_basepath / sim_settings.checkpoint_storage] + for fn in fns: + os.remove(fn) + finally: + energy_context_cache.empty() + sampler_context_cache.empty() + + if not dry: # pragma: no-cover return { 'nc': nc, 'last_checkpoint': chk, 'unit_estimate': est, } else: - # close reporter when you're done, prevent file handle clashes - reporter.close() - - # clean up the reporter file - fns = [shared_basepath / sim_settings.output_filename, - shared_basepath / sim_settings.checkpoint_storage] - for fn in fns: - os.remove(fn) return {'debug': {'sampler': sampler}} def _execute( From 8e8df34ea954b9f74acb5a482b06b0913003d826 Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 20 Apr 2023 08:40:38 -0700 Subject: [PATCH 02/11] Changes from live power hour discussion --- .../openmm_rfe/_rfe_utils/multistate.py | 17 +++--- .../protocols/openmm_rfe/equil_rfe_methods.py | 53 ++++++++++++------- 2 files changed, 44 insertions(+), 26 deletions(-) diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index efcf85dad..2994dc43d 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -287,10 +287,13 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat context, integrator = cache.global_context_cache.get_context( thermodynamic_state, integrator ) - sampler_state.apply_to_context( - context, ignore_velocities=True - ) - openmm.LocalEnergyMinimizer.minimize( - context, maxIterations=max_iterations - ) - sampler_state.update_from_context(context) + try: + sampler_state.apply_to_context( + context, ignore_velocities=True + ) + openmm.LocalEnergyMinimizer.minimize( + context, maxIterations=max_iterations + ) + sampler_state.update_from_context(context) + finally: + del context, integrator diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index e71cf7a14..dcc174119 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -81,7 +81,13 @@ def get_estimate(self): a Quantity defined with units. """ # TODO: Check this holds up completely for SAMS. - dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] + # this v + dGs = [] + for pus in self.data.values(): + dGs.extend([pu.outputs['unit_estimate'] for pu in pus]) + + # not this v + #dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] u = dGs[0].u # convert all values to units of the first value, then take average of magnitude # this would avoid a screwy case where each value was in different units @@ -91,7 +97,13 @@ def get_estimate(self): def get_uncertainty(self): """The uncertainty/error in the dG value: The std of the estimates of each independent repeat""" - dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] + # this v + dGs = [] + for pus in self.data.values(): + dGs.extend([pu.outputs['unit_estimate'] for pu in pus]) + + # not this v + #dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] u = dGs[0].u # convert all values to units of the first value, then take average of magnitude # this would avoid a screwy case where each value was in different units @@ -591,18 +603,19 @@ def run(self, *, dry=False, verbose=True, capacity=None, time_to_live=None, platform=platform, ) - try: - # b. create langevin integrator - integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove( - timestep=to_openmm(integrator_settings.timestep), - collision_rate=to_openmm(integrator_settings.collision_rate), - n_steps=integrator_settings.n_steps.m, - reassign_velocities=integrator_settings.reassign_velocities, - n_restart_attempts=integrator_settings.n_restart_attempts, - constraint_tolerance=integrator_settings.constraint_tolerance, - splitting=integrator_settings.splitting - ) + # b. create langevin integrator + integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove( + timestep=to_openmm(integrator_settings.timestep), + collision_rate=to_openmm(integrator_settings.collision_rate), + n_steps=integrator_settings.n_steps.m, + reassign_velocities=integrator_settings.reassign_velocities, + n_restart_attempts=integrator_settings.n_restart_attempts, + constraint_tolerance=integrator_settings.constraint_tolerance, + splitting=integrator_settings.splitting + ) + + try: # 12. Create sampler if sampler_settings.sampler_method.lower() == "repex": sampler = _rfe_utils.multistate.HybridRepexSampler( @@ -670,24 +683,26 @@ def run(self, *, dry=False, verbose=True, est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole) est = ensure_quantity(est, 'openff') - # close reporter when you're done - reporter.close() - nc = shared_basepath / sim_settings.output_filename chk = shared_basepath / sim_settings.checkpoint_storage else: - # close reporter when you're done, prevent file handle clashes - reporter.close() - # clean up the reporter file fns = [shared_basepath / sim_settings.output_filename, shared_basepath / sim_settings.checkpoint_storage] for fn in fns: os.remove(fn) finally: + # close reporter when you're done, prevent file handle clashes + reporter.close() + + # clear GPU contexts energy_context_cache.empty() sampler_context_cache.empty() + del sampler_context_cache, energy_context_cache + del integrator, sampler + + if not dry: # pragma: no-cover return { 'nc': nc, From abd3ae5abefb523ac6c469a50ee6bbef08fe3476 Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 20 Apr 2023 08:57:05 -0700 Subject: [PATCH 03/11] Remove changes to ProtocolResult --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index dcc174119..db914456f 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -81,13 +81,7 @@ def get_estimate(self): a Quantity defined with units. """ # TODO: Check this holds up completely for SAMS. - # this v - dGs = [] - for pus in self.data.values(): - dGs.extend([pu.outputs['unit_estimate'] for pu in pus]) - - # not this v - #dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] + dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] u = dGs[0].u # convert all values to units of the first value, then take average of magnitude # this would avoid a screwy case where each value was in different units @@ -97,13 +91,7 @@ def get_estimate(self): def get_uncertainty(self): """The uncertainty/error in the dG value: The std of the estimates of each independent repeat""" - # this v - dGs = [] - for pus in self.data.values(): - dGs.extend([pu.outputs['unit_estimate'] for pu in pus]) - - # not this v - #dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] + dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] u = dGs[0].u # convert all values to units of the first value, then take average of magnitude # this would avoid a screwy case where each value was in different units From d5578431e3213b4fef8c1143b14afbfc4abd5bb2 Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 20 Apr 2023 11:20:36 -0700 Subject: [PATCH 04/11] Addressing @ialibay review --- .../protocols/openmm_rfe/equil_rfe_methods.py | 88 ++++++++++--------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index db914456f..d16c253a1 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -602,47 +602,46 @@ def run(self, *, dry=False, verbose=True, splitting=integrator_settings.splitting ) - - try: - # 12. Create sampler - if sampler_settings.sampler_method.lower() == "repex": - sampler = _rfe_utils.multistate.HybridRepexSampler( - mcmc_moves=integrator, - hybrid_factory=hybrid_factory, - online_analysis_interval=sampler_settings.online_analysis_interval, - online_analysis_target_error=sampler_settings.online_analysis_target_error.m, - online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations - ) - elif sampler_settings.sampler_method.lower() == "sams": - sampler = _rfe_utils.multistate.HybridSAMSSampler( - mcmc_moves=integrator, - hybrid_factory=hybrid_factory, - online_analysis_interval=sampler_settings.online_analysis_interval, - online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations, - flatness_criteria=sampler_settings.flatness_criteria, - gamma0=sampler_settings.gamma0, - ) - elif sampler_settings.sampler_method.lower() == 'independent': - sampler = _rfe_utils.multistate.HybridMultiStateSampler( - mcmc_moves=integrator, - hybrid_factory=hybrid_factory, - online_analysis_interval=sampler_settings.online_analysis_interval, - online_analysis_target_error=sampler_settings.online_analysis_target_error.m, - online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations - ) - - else: - raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}") - - sampler.setup( - n_replicas=sampler_settings.n_replicas, - reporter=reporter, - platform=platform, - lambda_protocol=lambdas, - temperature=to_openmm(protocol_settings.thermo_settings.temperature), - endstates=alchem_settings.unsampled_endstates, + # 12. Create sampler + if sampler_settings.sampler_method.lower() == "repex": + sampler = _rfe_utils.multistate.HybridRepexSampler( + mcmc_moves=integrator, + hybrid_factory=hybrid_factory, + online_analysis_interval=sampler_settings.online_analysis_interval, + online_analysis_target_error=sampler_settings.online_analysis_target_error.m, + online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations + ) + elif sampler_settings.sampler_method.lower() == "sams": + sampler = _rfe_utils.multistate.HybridSAMSSampler( + mcmc_moves=integrator, + hybrid_factory=hybrid_factory, + online_analysis_interval=sampler_settings.online_analysis_interval, + online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations, + flatness_criteria=sampler_settings.flatness_criteria, + gamma0=sampler_settings.gamma0, + ) + elif sampler_settings.sampler_method.lower() == 'independent': + sampler = _rfe_utils.multistate.HybridMultiStateSampler( + mcmc_moves=integrator, + hybrid_factory=hybrid_factory, + online_analysis_interval=sampler_settings.online_analysis_interval, + online_analysis_target_error=sampler_settings.online_analysis_target_error.m, + online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations ) + else: + raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}") + + sampler.setup( + n_replicas=sampler_settings.n_replicas, + reporter=reporter, + platform=platform, + lambda_protocol=lambdas, + temperature=to_openmm(protocol_settings.thermo_settings.temperature), + endstates=alchem_settings.unsampled_endstates, + ) + + try: sampler.energy_context_cache = energy_context_cache sampler.sampler_context_cache = sampler_context_cache @@ -684,13 +683,18 @@ def run(self, *, dry=False, verbose=True, reporter.close() # clear GPU contexts - energy_context_cache.empty() - sampler_context_cache.empty() + #energy_context_cache.empty() + #sampler_context_cache.empty() + # TODO: remove once upstream solution in place: https://github.com/choderalab/openmmtools/pull/690 + # replace with above + for context in list(energy_context_cache._data.keys()): + del self._data[context] + for context in list(sampler_context_cache._data.keys()): + del self._data[context] del sampler_context_cache, energy_context_cache del integrator, sampler - if not dry: # pragma: no-cover return { 'nc': nc, From f65a1f2801f324d6dca397838f8230d6b10e23d7 Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 20 Apr 2023 11:30:25 -0700 Subject: [PATCH 05/11] Needed to go through one other layer for context deletion --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index f7427f066..371e1839a 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -688,9 +688,9 @@ def run(self, *, dry=False, verbose=True, # TODO: remove once upstream solution in place: https://github.com/choderalab/openmmtools/pull/690 # replace with above for context in list(energy_context_cache._data.keys()): - del self._data[context] + del self._lru._data[context] for context in list(sampler_context_cache._data.keys()): - del self._data[context] + del self._lru._data[context] del sampler_context_cache, energy_context_cache del integrator, sampler From 65d2c91e7030a33f4b13bbef26df38244e6aef5c Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 20 Apr 2023 11:33:32 -0700 Subject: [PATCH 06/11] Missed some :/ --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 371e1839a..927c6ede7 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -687,9 +687,9 @@ def run(self, *, dry=False, verbose=True, #sampler_context_cache.empty() # TODO: remove once upstream solution in place: https://github.com/choderalab/openmmtools/pull/690 # replace with above - for context in list(energy_context_cache._data.keys()): + for context in list(energy_context_cache._lru._data.keys()): del self._lru._data[context] - for context in list(sampler_context_cache._data.keys()): + for context in list(sampler_context_cache._lru._data.keys()): del self._lru._data[context] del sampler_context_cache, energy_context_cache From 7fe0d2cabfda0658339fedefa4a9b9d6f5e915c3 Mon Sep 17 00:00:00 2001 From: David Dotson Date: Thu, 20 Apr 2023 11:36:00 -0700 Subject: [PATCH 07/11] Not enough sleep today... --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 927c6ede7..36dfa011f 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -688,9 +688,9 @@ def run(self, *, dry=False, verbose=True, # TODO: remove once upstream solution in place: https://github.com/choderalab/openmmtools/pull/690 # replace with above for context in list(energy_context_cache._lru._data.keys()): - del self._lru._data[context] + del energy_context_cache._lru._data[context] for context in list(sampler_context_cache._lru._data.keys()): - del self._lru._data[context] + del sampler_context_cache._lru._data[context] del sampler_context_cache, energy_context_cache del integrator, sampler From 73acb1ffe48fa68946c7ab60c35148549e15c5c2 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 20 Apr 2023 22:40:32 +0100 Subject: [PATCH 08/11] clear up global context cache + also remove cache bits in multistate --- .../openmm_rfe/_rfe_utils/multistate.py | 12 ++++---- .../protocols/openmm_rfe/equil_rfe_methods.py | 30 +++++++++++-------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index 2994dc43d..1c34bab06 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -34,7 +34,7 @@ def __init__(self, *args, hybrid_factory=None, **kwargs): self._hybrid_factory = hybrid_factory super(HybridCompatibilityMixin, self).__init__(*args, **kwargs) - def setup(self, reporter, platform, lambda_protocol, + def setup(self, reporter, lambda_protocol, temperature=298.15 * unit.kelvin, n_replicas=None, endstates=True, minimization_steps=100): """ @@ -45,8 +45,6 @@ def setup(self, reporter, platform, lambda_protocol, ---------- reporter : OpenMM reporter Simulation reporter to attach to each simulation replica. - platform : openmm.Platform - Platform to perform simulation on. lambda_protocol : LambdaProtocol The lambda protocol to be used for simulation. Default to a default class creation of LambdaProtocol. @@ -84,8 +82,6 @@ class creation of LambdaProtocol. thermodynamic_state_list = [] sampler_state_list = [] - context_cache = cache.ContextCache(platform) - if n_replicas is None: msg = (f"setting number of replicas to number of states: {n_states}") warnings.warn(msg) @@ -118,8 +114,6 @@ class creation of LambdaProtocol. # now generating a sampler_state for each thermodyanmic state, # with relaxed positions - context, context_integrator = context_cache.get_context( - compound_thermostate_copy) minimize(compound_thermostate_copy, sampler_state, max_iterations=minimization_steps) sampler_state_list.append(copy.deepcopy(sampler_state)) @@ -296,4 +290,8 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat ) sampler_state.update_from_context(context) finally: + # cautiously clear out the global context cache too + for context in list(cache.global_context_cache._lru._data.keys()): + del cache.global_context_cache._lru._data[context] + del context, integrator diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 36dfa011f..b85149762 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -572,7 +572,7 @@ def run(self, *, dry=False, verbose=True, checkpoint_storage=shared_basepath / sim_settings.checkpoint_storage, ) - # 10. Get platform and context caches + # 10. Get platform platform = _rfe_utils.compute.get_openmm_platform( protocol_settings.engine_settings.compute_platform ) @@ -581,16 +581,6 @@ def run(self, *, dry=False, verbose=True, # a. get integrator settings integrator_settings = protocol_settings.integrator_settings - # a. Create context caches (energy + sampler) - # Note: these needs to exist on the compute node - energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, - ) - - sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, - ) - # b. create langevin integrator integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove( timestep=to_openmm(integrator_settings.timestep), @@ -635,13 +625,22 @@ def run(self, *, dry=False, verbose=True, sampler.setup( n_replicas=sampler_settings.n_replicas, reporter=reporter, - platform=platform, lambda_protocol=lambdas, temperature=to_openmm(protocol_settings.thermo_settings.temperature), endstates=alchem_settings.unsampled_endstates, ) try: + # Create context caches (energy + sampler) + # Note: these needs to exist on the compute node + energy_context_cache = openmmtools.cache.ContextCache( + capacity=None, time_to_live=None, platform=platform, + ) + + sampler_context_cache = openmmtools.cache.ContextCache( + capacity=None, time_to_live=None, platform=platform, + ) + sampler.energy_context_cache = energy_context_cache sampler.sampler_context_cache = sampler_context_cache @@ -691,9 +690,14 @@ def run(self, *, dry=False, verbose=True, del energy_context_cache._lru._data[context] for context in list(sampler_context_cache._lru._data.keys()): del sampler_context_cache._lru._data[context] + # cautiously clear out the global context cache too + for context in list( + openmmtools.cache.global_context_cache._lru._data.keys()): + del cache.global_context_cache._lru._data[context] del sampler_context_cache, energy_context_cache - del integrator, sampler + if not dry: + del integrator, sampler if not dry: # pragma: no-cover return { From 68102084df83f0957a207e80410ad374ebe4ef24 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 20 Apr 2023 22:46:36 +0100 Subject: [PATCH 09/11] oops forgot to change this --- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index b85149762..718d04a30 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -693,7 +693,7 @@ def run(self, *, dry=False, verbose=True, # cautiously clear out the global context cache too for context in list( openmmtools.cache.global_context_cache._lru._data.keys()): - del cache.global_context_cache._lru._data[context] + del openmmtools.cache.global_context_cache._lru._data[context] del sampler_context_cache, energy_context_cache if not dry: From 9abe3c96eb2b1fcbca28d74e02dc079b18b13199 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 21 Apr 2023 23:50:00 +0100 Subject: [PATCH 10/11] some host memory clearning and enforce CPU platform for initial minimization --- .../protocols/openmm_rfe/_rfe_utils/multistate.py | 15 ++++++++------- openfe/protocols/openmm_rfe/equil_rfe_methods.py | 11 +++++++---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index 1c34bab06..cb546a8d9 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -118,6 +118,8 @@ class creation of LambdaProtocol. max_iterations=minimization_steps) sampler_state_list.append(copy.deepcopy(sampler_state)) + del compound_thermostate, sampler_state + # making sure number of sampler states equals n_replicas if len(sampler_state_list) != n_replicas: # picking roughly evenly spaced sampler states @@ -277,8 +279,11 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat sampler_state : openmmtools.states.SamplerState The posititions and accompanying state following minimization """ - integrator = openmm.VerletIntegrator(1.0) #we won't take any steps, so use a simple integrator - context, integrator = cache.global_context_cache.get_context( + # we won't take any steps, so use a simple integrator + integrator = openmm.VerletIntegrator(1.0) + platform = openmm.Platform.getPlatformByName('CPU') + dummy_cache = cache.DummyContextCache(platform=platform) + context, integrator = dummy_cache.get_context( thermodynamic_state, integrator ) try: @@ -290,8 +295,4 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat ) sampler_state.update_from_context(context) finally: - # cautiously clear out the global context cache too - for context in list(cache.global_context_cache._lru._data.keys()): - del cache.global_context_cache._lru._data[context] - - del context, integrator + del context, integrator, dummy_cache diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 718d04a30..148ed422c 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -632,7 +632,6 @@ def run(self, *, dry=False, verbose=True, try: # Create context caches (energy + sampler) - # Note: these needs to exist on the compute node energy_context_cache = openmmtools.cache.ContextCache( capacity=None, time_to_live=None, platform=platform, ) @@ -680,11 +679,15 @@ def run(self, *, dry=False, verbose=True, finally: # close reporter when you're done, prevent file handle clashes reporter.close() + del reporter + + # clean up the analyzer + if not dry: + ana.clear() + del ana # clear GPU contexts - #energy_context_cache.empty() - #sampler_context_cache.empty() - # TODO: remove once upstream solution in place: https://github.com/choderalab/openmmtools/pull/690 + # TODO: use cache.empty() calls when openmmtools #690 is resolved # replace with above for context in list(energy_context_cache._lru._data.keys()): del energy_context_cache._lru._data[context] From b73dfc673ce0d3b945cda91236fdaa86992ebfe9 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 22 Apr 2023 08:10:13 +0100 Subject: [PATCH 11/11] pass platform through to pre-minimization --- .../openmm_rfe/_rfe_utils/multistate.py | 21 +++++++++++++------ .../protocols/openmm_rfe/equil_rfe_methods.py | 1 + 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index cb546a8d9..a269b1b5b 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -36,7 +36,8 @@ def __init__(self, *args, hybrid_factory=None, **kwargs): def setup(self, reporter, lambda_protocol, temperature=298.15 * unit.kelvin, n_replicas=None, - endstates=True, minimization_steps=100): + endstates=True, minimization_steps=100, + minimization_platform="CPU"): """ Setup MultistateSampler based on the input lambda protocol and number of replicas. @@ -58,7 +59,9 @@ class creation of LambdaProtocol. Whether or not to generate unsampled endstates (i.e. dispersion correction). minimization_steps : int - Number of steps to minimize states. + Number of steps to pre-minimize states. + minimization_platform : str + Platform to do the initial pre-minimization with. Attributes ---------- @@ -114,8 +117,10 @@ class creation of LambdaProtocol. # now generating a sampler_state for each thermodyanmic state, # with relaxed positions + # Note: remove once choderalab/openmmtools#672 is completed minimize(compound_thermostate_copy, sampler_state, - max_iterations=minimization_steps) + max_iterations=minimization_steps, + platform_name=minimization_platform) sampler_state_list.append(copy.deepcopy(sampler_state)) del compound_thermostate, sampler_state @@ -257,8 +262,10 @@ def create_endstates(first_thermostate, last_thermostate): return unsampled_endstates -def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: states.SamplerState, - max_iterations: int=100) -> states.SamplerState: +def minimize(thermodynamic_state: states.ThermodynamicState, + sampler_state: states.SamplerState, + max_iterations: int=100, + platform_name: str="CPU") -> states.SamplerState: """ Adapted from perses.dispersed.feptasks.minimize @@ -273,6 +280,8 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat The starting state at which to minimize the system. max_iterations : int, optional, default 100 The maximum number of minimization steps. Default is 100. + platform_name : str + The OpenMM platform name to carry out the minimization with. Returns ------- @@ -281,7 +290,7 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat """ # we won't take any steps, so use a simple integrator integrator = openmm.VerletIntegrator(1.0) - platform = openmm.Platform.getPlatformByName('CPU') + platform = openmm.Platform.getPlatformByName(platform_name) dummy_cache = cache.DummyContextCache(platform=platform) context, integrator = dummy_cache.get_context( thermodynamic_state, integrator diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index 148ed422c..3ebf99251 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -628,6 +628,7 @@ def run(self, *, dry=False, verbose=True, lambda_protocol=lambdas, temperature=to_openmm(protocol_settings.thermo_settings.temperature), endstates=alchem_settings.unsampled_endstates, + minimization_platform=platform.getName(), ) try: