-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cleanup contexts upon calculation completion, failure #354
Changes from 12 commits
c76084e
bef1fb2
8e8df34
abd3ae5
d557843
ea3feb5
f65a1f2
65d2c91
7fe0d2c
73acb1f
6810208
9abe3c9
b73dfc6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ def __init__(self, *args, hybrid_factory=None, **kwargs): | |
self._hybrid_factory = hybrid_factory | ||
super(HybridCompatibilityMixin, self).__init__(*args, **kwargs) | ||
|
||
def setup(self, reporter, platform, lambda_protocol, | ||
def setup(self, reporter, lambda_protocol, | ||
temperature=298.15 * unit.kelvin, n_replicas=None, | ||
endstates=True, minimization_steps=100): | ||
""" | ||
|
@@ -45,8 +45,6 @@ def setup(self, reporter, platform, lambda_protocol, | |
---------- | ||
reporter : OpenMM reporter | ||
Simulation reporter to attach to each simulation replica. | ||
platform : openmm.Platform | ||
Platform to perform simulation on. | ||
lambda_protocol : LambdaProtocol | ||
The lambda protocol to be used for simulation. Default to a default | ||
class creation of LambdaProtocol. | ||
|
@@ -84,8 +82,6 @@ class creation of LambdaProtocol. | |
thermodynamic_state_list = [] | ||
sampler_state_list = [] | ||
|
||
context_cache = cache.ContextCache(platform) | ||
|
||
if n_replicas is None: | ||
msg = (f"setting number of replicas to number of states: {n_states}") | ||
warnings.warn(msg) | ||
|
@@ -118,12 +114,12 @@ class creation of LambdaProtocol. | |
|
||
# now generating a sampler_state for each thermodyanmic state, | ||
# with relaxed positions | ||
context, context_integrator = context_cache.get_context( | ||
compound_thermostate_copy) | ||
minimize(compound_thermostate_copy, sampler_state, | ||
max_iterations=minimization_steps) | ||
sampler_state_list.append(copy.deepcopy(sampler_state)) | ||
|
||
del compound_thermostate, sampler_state | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that should be fine. |
||
|
||
# making sure number of sampler states equals n_replicas | ||
if len(sampler_state_list) != n_replicas: | ||
# picking roughly evenly spaced sampler states | ||
|
@@ -283,14 +279,20 @@ def minimize(thermodynamic_state: states.ThermodynamicState, sampler_state: stat | |
sampler_state : openmmtools.states.SamplerState | ||
The posititions and accompanying state following minimization | ||
""" | ||
integrator = openmm.VerletIntegrator(1.0) #we won't take any steps, so use a simple integrator | ||
context, integrator = cache.global_context_cache.get_context( | ||
# we won't take any steps, so use a simple integrator | ||
integrator = openmm.VerletIntegrator(1.0) | ||
platform = openmm.Platform.getPlatformByName('CPU') | ||
dummy_cache = cache.DummyContextCache(platform=platform) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using a dummy context cache for the convenience of being able to pass a thermodynamic state object rather than doubling up efforts and doing the 3-4 extra lines here ourselves |
||
context, integrator = dummy_cache.get_context( | ||
thermodynamic_state, integrator | ||
) | ||
sampler_state.apply_to_context( | ||
context, ignore_velocities=True | ||
) | ||
openmm.LocalEnergyMinimizer.minimize( | ||
context, maxIterations=max_iterations | ||
) | ||
sampler_state.update_from_context(context) | ||
try: | ||
sampler_state.apply_to_context( | ||
context, ignore_velocities=True | ||
) | ||
openmm.LocalEnergyMinimizer.minimize( | ||
context, maxIterations=max_iterations | ||
) | ||
sampler_state.update_from_context(context) | ||
finally: | ||
del context, integrator, dummy_cache | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Being a bit overly cautious but it'll do I think. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -323,20 +323,20 @@ def run(self, *, dry=False, verbose=True, | |
|
||
# a. check timestep correctness + that | ||
# equilibration & production are divisible by n_steps | ||
prototol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings'] | ||
protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings'] | ||
stateA = self._inputs['stateA'] | ||
stateB = self._inputs['stateB'] | ||
mapping = self._inputs['ligandmapping'] | ||
|
||
forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = prototol_settings.forcefield_settings | ||
thermo_settings: settings.ThermoSettings = prototol_settings.thermo_settings | ||
alchem_settings: AlchemicalSettings = prototol_settings.alchemical_settings | ||
system_settings: SystemSettings = prototol_settings.system_settings | ||
solvation_settings: SolvationSettings = prototol_settings.solvation_settings | ||
sampler_settings: AlchemicalSamplerSettings = prototol_settings.alchemical_sampler_settings | ||
sim_settings: SimulationSettings = prototol_settings.simulation_settings | ||
timestep = prototol_settings.integrator_settings.timestep | ||
mc_steps = prototol_settings.integrator_settings.n_steps.m | ||
forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings | ||
thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings | ||
alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings | ||
system_settings: SystemSettings = protocol_settings.system_settings | ||
solvation_settings: SolvationSettings = protocol_settings.solvation_settings | ||
sampler_settings: AlchemicalSamplerSettings = protocol_settings.alchemical_sampler_settings | ||
sim_settings: SimulationSettings = protocol_settings.simulation_settings | ||
timestep = protocol_settings.integrator_settings.timestep | ||
mc_steps = protocol_settings.integrator_settings.n_steps.m | ||
|
||
# is the timestep good for the mass? | ||
if forcefield_settings.hydrogen_mass < 3.0: | ||
|
@@ -536,9 +536,9 @@ def run(self, *, dry=False, verbose=True, | |
if 'solvent' in stateA.components: | ||
hybrid_factory.hybrid_system.addForce( | ||
openmm.MonteCarloBarostat( | ||
prototol_settings.thermo_settings.pressure.to(unit.bar).m, | ||
prototol_settings.thermo_settings.temperature.m, | ||
prototol_settings.integrator_settings.barostat_frequency.m, | ||
protocol_settings.thermo_settings.pressure.to(unit.bar).m, | ||
protocol_settings.thermo_settings.temperature.m, | ||
protocol_settings.integrator_settings.barostat_frequency.m, | ||
) | ||
) | ||
|
||
|
@@ -572,24 +572,14 @@ def run(self, *, dry=False, verbose=True, | |
checkpoint_storage=shared_basepath / sim_settings.checkpoint_storage, | ||
) | ||
|
||
# 10. Get platform and context caches | ||
# 10. Get platform | ||
platform = _rfe_utils.compute.get_openmm_platform( | ||
prototol_settings.engine_settings.compute_platform | ||
) | ||
|
||
# a. Create context caches (energy + sampler) | ||
# Note: these needs to exist on the compute node | ||
energy_context_cache = openmmtools.cache.ContextCache( | ||
capacity=None, time_to_live=None, platform=platform, | ||
) | ||
|
||
sampler_context_cache = openmmtools.cache.ContextCache( | ||
capacity=None, time_to_live=None, platform=platform, | ||
protocol_settings.engine_settings.compute_platform | ||
) | ||
|
||
# 11. Set the integrator | ||
dotsdl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# a. get integrator settings | ||
integrator_settings = prototol_settings.integrator_settings | ||
integrator_settings = protocol_settings.integrator_settings | ||
|
||
# b. create langevin integrator | ||
integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove( | ||
|
@@ -635,59 +625,90 @@ def run(self, *, dry=False, verbose=True, | |
sampler.setup( | ||
n_replicas=sampler_settings.n_replicas, | ||
reporter=reporter, | ||
platform=platform, | ||
lambda_protocol=lambdas, | ||
temperature=to_openmm(prototol_settings.thermo_settings.temperature), | ||
temperature=to_openmm(protocol_settings.thermo_settings.temperature), | ||
endstates=alchem_settings.unsampled_endstates, | ||
) | ||
|
||
sampler.energy_context_cache = energy_context_cache | ||
sampler.sampler_context_cache = sampler_context_cache | ||
try: | ||
IAlibay marked this conversation as resolved.
Show resolved
Hide resolved
IAlibay marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Create context caches (energy + sampler) | ||
energy_context_cache = openmmtools.cache.ContextCache( | ||
capacity=None, time_to_live=None, platform=platform, | ||
) | ||
|
||
if not dry: # pragma: no-cover | ||
# minimize | ||
if verbose: | ||
logger.info("minimizing systems") | ||
sampler_context_cache = openmmtools.cache.ContextCache( | ||
capacity=None, time_to_live=None, platform=platform, | ||
) | ||
|
||
sampler.minimize(max_iterations=sim_settings.minimization_steps) | ||
sampler.energy_context_cache = energy_context_cache | ||
dotsdl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sampler.sampler_context_cache = sampler_context_cache | ||
|
||
# equilibrate | ||
if verbose: | ||
logger.info("equilibrating systems") | ||
if not dry: # pragma: no-cover | ||
# minimize | ||
if verbose: | ||
logger.info("minimizing systems") | ||
|
||
sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore | ||
sampler.minimize(max_iterations=sim_settings.minimization_steps) | ||
|
||
# production | ||
if verbose: | ||
logger.info("running production phase") | ||
# equilibrate | ||
if verbose: | ||
logger.info("equilibrating systems") | ||
|
||
sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore | ||
sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore | ||
|
||
# calculate estimate of results from this individual unit | ||
ana = multistate.MultiStateSamplerAnalyzer(reporter) | ||
est, _ = ana.get_free_energy() | ||
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole) | ||
est = ensure_quantity(est, 'openff') | ||
# production | ||
if verbose: | ||
logger.info("running production phase") | ||
|
||
# close reporter when you're done | ||
sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore | ||
|
||
# calculate estimate of results from this individual unit | ||
ana = multistate.MultiStateSamplerAnalyzer(reporter) | ||
est, _ = ana.get_free_energy() | ||
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole) | ||
est = ensure_quantity(est, 'openff') | ||
|
||
nc = shared_basepath / sim_settings.output_filename | ||
chk = shared_basepath / sim_settings.checkpoint_storage | ||
else: | ||
# clean up the reporter file | ||
fns = [shared_basepath / sim_settings.output_filename, | ||
shared_basepath / sim_settings.checkpoint_storage] | ||
for fn in fns: | ||
os.remove(fn) | ||
finally: | ||
# close reporter when you're done, prevent file handle clashes | ||
reporter.close() | ||
del reporter | ||
|
||
# clean up the analyzer | ||
if not dry: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs a bit of a cleanup, but I'll do it in the upcoming refactor. |
||
ana.clear() | ||
del ana | ||
|
||
# clear GPU contexts | ||
# TODO: use cache.empty() calls when openmmtools #690 is resolved | ||
# replace with above | ||
for context in list(energy_context_cache._lru._data.keys()): | ||
del energy_context_cache._lru._data[context] | ||
for context in list(sampler_context_cache._lru._data.keys()): | ||
del sampler_context_cache._lru._data[context] | ||
# cautiously clear out the global context cache too | ||
for context in list( | ||
openmmtools.cache.global_context_cache._lru._data.keys()): | ||
del openmmtools.cache.global_context_cache._lru._data[context] | ||
|
||
del sampler_context_cache, energy_context_cache | ||
if not dry: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. won't delete the integrator on dry, one would hope that alchemiscale style services wouldn't be running in dry mode anyways? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I don't see a scenario where we would use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems more of a user-host thing for debug. |
||
del integrator, sampler | ||
|
||
nc = shared_basepath / sim_settings.output_filename | ||
chk = shared_basepath / sim_settings.checkpoint_storage | ||
if not dry: # pragma: no-cover | ||
return { | ||
'nc': nc, | ||
'last_checkpoint': chk, | ||
'unit_estimate': est, | ||
} | ||
else: | ||
# close reporter when you're done, prevent file handle clashes | ||
reporter.close() | ||
|
||
# clean up the reporter file | ||
fns = [shared_basepath / sim_settings.output_filename, | ||
shared_basepath / sim_settings.checkpoint_storage] | ||
for fn in fns: | ||
os.remove(fn) | ||
return {'debug': {'sampler': sampler}} | ||
|
||
def _execute( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
creating a context here was absolutely unecessary, so it's been removed