Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup contexts upon calculation completion, failure #354

Merged
merged 13 commits into from
Apr 22, 2023
34 changes: 18 additions & 16 deletions openfe/protocols/openmm_rfe/_rfe_utils/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, *args, hybrid_factory=None, **kwargs):
self._hybrid_factory = hybrid_factory
super(HybridCompatibilityMixin, self).__init__(*args, **kwargs)

def setup(self, reporter, platform, lambda_protocol,
def setup(self, reporter, lambda_protocol,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a context here was absolutely unnecessary, so it's been removed.

temperature=298.15 * unit.kelvin, n_replicas=None,
endstates=True, minimization_steps=100):
"""
Expand All @@ -45,8 +45,6 @@ def setup(self, reporter, platform, lambda_protocol,
----------
reporter : OpenMM reporter
Simulation reporter to attach to each simulation replica.
platform : openmm.Platform
Platform to perform simulation on.
lambda_protocol : LambdaProtocol
The lambda protocol to be used for simulation. Default to a default
class creation of LambdaProtocol.
Expand Down Expand Up @@ -84,8 +82,6 @@ class creation of LambdaProtocol.
thermodynamic_state_list = []
sampler_state_list = []

context_cache = cache.ContextCache(platform)

if n_replicas is None:
msg = (f"setting number of replicas to number of states: {n_states}")
warnings.warn(msg)
Expand Down Expand Up @@ -118,12 +114,12 @@ class creation of LambdaProtocol.

# now generating a sampler_state for each thermodyanmic state,
# with relaxed positions
context, context_integrator = context_cache.get_context(
compound_thermostate_copy)
minimize(compound_thermostate_copy, sampler_state,
max_iterations=minimization_steps)
sampler_state_list.append(copy.deepcopy(sampler_state))

del compound_thermostate, sampler_state
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that should be fine.


# making sure number of sampler states equals n_replicas
if len(sampler_state_list) != n_replicas:
# picking roughly evenly spaced sampler states
Expand Down Expand Up @@ -283,14 +279,20 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat
sampler_state : openmmtools.states.SamplerState
The posititions and accompanying state following minimization
"""
integrator = openmm.VerletIntegrator(1.0) #we won't take any steps, so use a simple integrator
context, integrator = cache.global_context_cache.get_context(
# we won't take any steps, so use a simple integrator
integrator = openmm.VerletIntegrator(1.0)
platform = openmm.Platform.getPlatformByName('CPU')
dummy_cache = cache.DummyContextCache(platform=platform)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a dummy context cache here for convenience: it lets us pass a thermodynamic state object directly, rather than duplicating effort and writing the 3-4 extra lines ourselves.

context, integrator = dummy_cache.get_context(
thermodynamic_state, integrator
)
sampler_state.apply_to_context(
context, ignore_velocities=True
)
openmm.LocalEnergyMinimizer.minimize(
context, maxIterations=max_iterations
)
sampler_state.update_from_context(context)
try:
sampler_state.apply_to_context(
context, ignore_velocities=True
)
openmm.LocalEnergyMinimizer.minimize(
context, maxIterations=max_iterations
)
sampler_state.update_from_context(context)
finally:
del context, integrator, dummy_cache
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Being a bit overly cautious but it'll do I think.

139 changes: 80 additions & 59 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,20 +323,20 @@ def run(self, *, dry=False, verbose=True,

# a. check timestep correctness + that
# equilibration & production are divisible by n_steps
prototol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings']
protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings']
stateA = self._inputs['stateA']
stateB = self._inputs['stateB']
mapping = self._inputs['ligandmapping']

forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = prototol_settings.forcefield_settings
thermo_settings: settings.ThermoSettings = prototol_settings.thermo_settings
alchem_settings: AlchemicalSettings = prototol_settings.alchemical_settings
system_settings: SystemSettings = prototol_settings.system_settings
solvation_settings: SolvationSettings = prototol_settings.solvation_settings
sampler_settings: AlchemicalSamplerSettings = prototol_settings.alchemical_sampler_settings
sim_settings: SimulationSettings = prototol_settings.simulation_settings
timestep = prototol_settings.integrator_settings.timestep
mc_steps = prototol_settings.integrator_settings.n_steps.m
forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings
thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings
alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings
system_settings: SystemSettings = protocol_settings.system_settings
solvation_settings: SolvationSettings = protocol_settings.solvation_settings
sampler_settings: AlchemicalSamplerSettings = protocol_settings.alchemical_sampler_settings
sim_settings: SimulationSettings = protocol_settings.simulation_settings
timestep = protocol_settings.integrator_settings.timestep
mc_steps = protocol_settings.integrator_settings.n_steps.m

# is the timestep good for the mass?
if forcefield_settings.hydrogen_mass < 3.0:
Expand Down Expand Up @@ -536,9 +536,9 @@ def run(self, *, dry=False, verbose=True,
if 'solvent' in stateA.components:
hybrid_factory.hybrid_system.addForce(
openmm.MonteCarloBarostat(
prototol_settings.thermo_settings.pressure.to(unit.bar).m,
prototol_settings.thermo_settings.temperature.m,
prototol_settings.integrator_settings.barostat_frequency.m,
protocol_settings.thermo_settings.pressure.to(unit.bar).m,
protocol_settings.thermo_settings.temperature.m,
protocol_settings.integrator_settings.barostat_frequency.m,
)
)

Expand Down Expand Up @@ -572,24 +572,14 @@ def run(self, *, dry=False, verbose=True,
checkpoint_storage=shared_basepath / sim_settings.checkpoint_storage,
)

# 10. Get platform and context caches
# 10. Get platform
platform = _rfe_utils.compute.get_openmm_platform(
prototol_settings.engine_settings.compute_platform
)

# a. Create context caches (energy + sampler)
# Note: these needs to exist on the compute node
energy_context_cache = openmmtools.cache.ContextCache(
capacity=None, time_to_live=None, platform=platform,
)

sampler_context_cache = openmmtools.cache.ContextCache(
capacity=None, time_to_live=None, platform=platform,
protocol_settings.engine_settings.compute_platform
)

# 11. Set the integrator
dotsdl marked this conversation as resolved.
Show resolved Hide resolved
# a. get integrator settings
integrator_settings = prototol_settings.integrator_settings
integrator_settings = protocol_settings.integrator_settings

# b. create langevin integrator
integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove(
Expand Down Expand Up @@ -635,59 +625,90 @@ def run(self, *, dry=False, verbose=True,
sampler.setup(
n_replicas=sampler_settings.n_replicas,
reporter=reporter,
platform=platform,
lambda_protocol=lambdas,
temperature=to_openmm(prototol_settings.thermo_settings.temperature),
temperature=to_openmm(protocol_settings.thermo_settings.temperature),
endstates=alchem_settings.unsampled_endstates,
)

sampler.energy_context_cache = energy_context_cache
sampler.sampler_context_cache = sampler_context_cache
try:
IAlibay marked this conversation as resolved.
Show resolved Hide resolved
IAlibay marked this conversation as resolved.
Show resolved Hide resolved
# Create context caches (energy + sampler)
energy_context_cache = openmmtools.cache.ContextCache(
capacity=None, time_to_live=None, platform=platform,
)

if not dry: # pragma: no-cover
# minimize
if verbose:
logger.info("minimizing systems")
sampler_context_cache = openmmtools.cache.ContextCache(
capacity=None, time_to_live=None, platform=platform,
)

sampler.minimize(max_iterations=sim_settings.minimization_steps)
sampler.energy_context_cache = energy_context_cache
dotsdl marked this conversation as resolved.
Show resolved Hide resolved
sampler.sampler_context_cache = sampler_context_cache

# equilibrate
if verbose:
logger.info("equilibrating systems")
if not dry: # pragma: no-cover
# minimize
if verbose:
logger.info("minimizing systems")

sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore
sampler.minimize(max_iterations=sim_settings.minimization_steps)

# production
if verbose:
logger.info("running production phase")
# equilibrate
if verbose:
logger.info("equilibrating systems")

sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore
sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore

# calculate estimate of results from this individual unit
ana = multistate.MultiStateSamplerAnalyzer(reporter)
est, _ = ana.get_free_energy()
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole)
est = ensure_quantity(est, 'openff')
# production
if verbose:
logger.info("running production phase")

# close reporter when you're done
sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore

# calculate estimate of results from this individual unit
ana = multistate.MultiStateSamplerAnalyzer(reporter)
est, _ = ana.get_free_energy()
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole)
est = ensure_quantity(est, 'openff')

nc = shared_basepath / sim_settings.output_filename
chk = shared_basepath / sim_settings.checkpoint_storage
else:
# clean up the reporter file
fns = [shared_basepath / sim_settings.output_filename,
shared_basepath / sim_settings.checkpoint_storage]
for fn in fns:
os.remove(fn)
finally:
# close reporter when you're done, prevent file handle clashes
reporter.close()
del reporter

# clean up the analyzer
if not dry:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a bit of a cleanup, but I'll do it in the upcoming refactor.

ana.clear()
del ana

# clear GPU contexts
# TODO: use cache.empty() calls when openmmtools #690 is resolved
# replace with above
for context in list(energy_context_cache._lru._data.keys()):
del energy_context_cache._lru._data[context]
for context in list(sampler_context_cache._lru._data.keys()):
del sampler_context_cache._lru._data[context]
# cautiously clear out the global context cache too
for context in list(
openmmtools.cache.global_context_cache._lru._data.keys()):
del openmmtools.cache.global_context_cache._lru._data[context]

del sampler_context_cache, energy_context_cache
if not dry:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't delete the integrator on a dry run; one would hope that alchemiscale-style services wouldn't be running in dry mode anyway?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I don't see a scenario where we would use dry on alchemiscale, so I think we're good here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems more of a user-host thing for debug.

del integrator, sampler

nc = shared_basepath / sim_settings.output_filename
chk = shared_basepath / sim_settings.checkpoint_storage
if not dry: # pragma: no-cover
return {
'nc': nc,
'last_checkpoint': chk,
'unit_estimate': est,
}
else:
# close reporter when you're done, prevent file handle clashes
reporter.close()

# clean up the reporter file
fns = [shared_basepath / sim_settings.output_filename,
shared_basepath / sim_settings.checkpoint_storage]
for fn in fns:
os.remove(fn)
return {'debug': {'sampler': sampler}}

def _execute(
Expand Down