Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup contexts upon calculation completion, failure #354

Merged
merged 13 commits into from
Apr 22, 2023
199 changes: 103 additions & 96 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,20 +323,20 @@ def run(self, *, dry=False, verbose=True,

# a. check timestep correctness + that
# equilibration & production are divisible by n_steps
prototol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings']
protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['settings']
stateA = self._inputs['stateA']
stateB = self._inputs['stateB']
mapping = self._inputs['ligandmapping']

forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = prototol_settings.forcefield_settings
thermo_settings: settings.ThermoSettings = prototol_settings.thermo_settings
alchem_settings: AlchemicalSettings = prototol_settings.alchemical_settings
system_settings: SystemSettings = prototol_settings.system_settings
solvation_settings: SolvationSettings = prototol_settings.solvation_settings
sampler_settings: AlchemicalSamplerSettings = prototol_settings.alchemical_sampler_settings
sim_settings: SimulationSettings = prototol_settings.simulation_settings
timestep = prototol_settings.integrator_settings.timestep
mc_steps = prototol_settings.integrator_settings.n_steps.m
forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings
thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings
alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings
system_settings: SystemSettings = protocol_settings.system_settings
solvation_settings: SolvationSettings = protocol_settings.solvation_settings
sampler_settings: AlchemicalSamplerSettings = protocol_settings.alchemical_sampler_settings
sim_settings: SimulationSettings = protocol_settings.simulation_settings
timestep = protocol_settings.integrator_settings.timestep
mc_steps = protocol_settings.integrator_settings.n_steps.m

# is the timestep good for the mass?
if forcefield_settings.hydrogen_mass < 3.0:
Expand Down Expand Up @@ -536,9 +536,9 @@ def run(self, *, dry=False, verbose=True,
if 'solvent' in stateA.components:
hybrid_factory.hybrid_system.addForce(
openmm.MonteCarloBarostat(
prototol_settings.thermo_settings.pressure.to(unit.bar).m,
prototol_settings.thermo_settings.temperature.m,
prototol_settings.integrator_settings.barostat_frequency.m,
protocol_settings.thermo_settings.pressure.to(unit.bar).m,
protocol_settings.thermo_settings.temperature.m,
protocol_settings.integrator_settings.barostat_frequency.m,
)
)

Expand Down Expand Up @@ -574,9 +574,13 @@ def run(self, *, dry=False, verbose=True,

# 10. Get platform and context caches
platform = _rfe_utils.compute.get_openmm_platform(
prototol_settings.engine_settings.compute_platform
protocol_settings.engine_settings.compute_platform
)

# 11. Set the integrator
dotsdl marked this conversation as resolved.
Show resolved Hide resolved
# a. get integrator settings
integrator_settings = protocol_settings.integrator_settings

# a. Create context caches (energy + sampler)
# Note: these need to exist on the compute node
energy_context_cache = openmmtools.cache.ContextCache(
Expand All @@ -587,107 +591,110 @@ def run(self, *, dry=False, verbose=True,
capacity=None, time_to_live=None, platform=platform,
)

# 11. Set the integrator
# a. get integrator settings
integrator_settings = prototol_settings.integrator_settings

# b. create langevin integrator
integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove(
timestep=to_openmm(integrator_settings.timestep),
collision_rate=to_openmm(integrator_settings.collision_rate),
n_steps=integrator_settings.n_steps.m,
reassign_velocities=integrator_settings.reassign_velocities,
n_restart_attempts=integrator_settings.n_restart_attempts,
constraint_tolerance=integrator_settings.constraint_tolerance,
splitting=integrator_settings.splitting
)

# 12. Create sampler
if sampler_settings.sampler_method.lower() == "repex":
sampler = _rfe_utils.multistate.HybridRepexSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
online_analysis_interval=sampler_settings.online_analysis_interval,
online_analysis_target_error=sampler_settings.online_analysis_target_error.m,
online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations
)
elif sampler_settings.sampler_method.lower() == "sams":
sampler = _rfe_utils.multistate.HybridSAMSSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
online_analysis_interval=sampler_settings.online_analysis_interval,
online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations,
flatness_criteria=sampler_settings.flatness_criteria,
gamma0=sampler_settings.gamma0,
try:
IAlibay marked this conversation as resolved.
Show resolved Hide resolved
IAlibay marked this conversation as resolved.
Show resolved Hide resolved
# b. create langevin integrator
integrator = openmmtools.mcmc.LangevinSplittingDynamicsMove(
timestep=to_openmm(integrator_settings.timestep),
collision_rate=to_openmm(integrator_settings.collision_rate),
n_steps=integrator_settings.n_steps.m,
reassign_velocities=integrator_settings.reassign_velocities,
n_restart_attempts=integrator_settings.n_restart_attempts,
constraint_tolerance=integrator_settings.constraint_tolerance,
splitting=integrator_settings.splitting
)
elif sampler_settings.sampler_method.lower() == 'independent':
sampler = _rfe_utils.multistate.HybridMultiStateSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
online_analysis_interval=sampler_settings.online_analysis_interval,
online_analysis_target_error=sampler_settings.online_analysis_target_error.m,
online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations

# 12. Create sampler
if sampler_settings.sampler_method.lower() == "repex":
sampler = _rfe_utils.multistate.HybridRepexSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
online_analysis_interval=sampler_settings.online_analysis_interval,
online_analysis_target_error=sampler_settings.online_analysis_target_error.m,
online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations
)
elif sampler_settings.sampler_method.lower() == "sams":
sampler = _rfe_utils.multistate.HybridSAMSSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
online_analysis_interval=sampler_settings.online_analysis_interval,
online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations,
flatness_criteria=sampler_settings.flatness_criteria,
gamma0=sampler_settings.gamma0,
)
elif sampler_settings.sampler_method.lower() == 'independent':
sampler = _rfe_utils.multistate.HybridMultiStateSampler(
mcmc_moves=integrator,
hybrid_factory=hybrid_factory,
online_analysis_interval=sampler_settings.online_analysis_interval,
online_analysis_target_error=sampler_settings.online_analysis_target_error.m,
online_analysis_minimum_iterations=sampler_settings.online_analysis_minimum_iterations
)

else:
raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}")

sampler.setup(
n_replicas=sampler_settings.n_replicas,
reporter=reporter,
platform=platform,
lambda_protocol=lambdas,
temperature=to_openmm(protocol_settings.thermo_settings.temperature),
endstates=alchem_settings.unsampled_endstates,
)

else:
raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}")

sampler.setup(
n_replicas=sampler_settings.n_replicas,
reporter=reporter,
platform=platform,
lambda_protocol=lambdas,
temperature=to_openmm(prototol_settings.thermo_settings.temperature),
endstates=alchem_settings.unsampled_endstates,
)
sampler.energy_context_cache = energy_context_cache
dotsdl marked this conversation as resolved.
Show resolved Hide resolved
sampler.sampler_context_cache = sampler_context_cache

sampler.energy_context_cache = energy_context_cache
sampler.sampler_context_cache = sampler_context_cache
if not dry: # pragma: no-cover
# minimize
if verbose:
logger.info("minimizing systems")

if not dry: # pragma: no-cover
# minimize
if verbose:
logger.info("minimizing systems")
sampler.minimize(max_iterations=sim_settings.minimization_steps)

sampler.minimize(max_iterations=sim_settings.minimization_steps)
# equilibrate
if verbose:
logger.info("equilibrating systems")

# equilibrate
if verbose:
logger.info("equilibrating systems")
sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore

sampler.equilibrate(int(equil_steps.m / mc_steps)) # type: ignore
# production
if verbose:
logger.info("running production phase")

# production
if verbose:
logger.info("running production phase")
sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore

sampler.extend(int(prod_steps.m / mc_steps)) # type: ignore
# calculate estimate of results from this individual unit
ana = multistate.MultiStateSamplerAnalyzer(reporter)
est, _ = ana.get_free_energy()
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole)
est = ensure_quantity(est, 'openff')

# calculate estimate of results from this individual unit
ana = multistate.MultiStateSamplerAnalyzer(reporter)
est, _ = ana.get_free_energy()
est = (est[0, -1] * ana.kT).in_units_of(omm_unit.kilocalories_per_mole)
est = ensure_quantity(est, 'openff')
# close reporter when you're done
reporter.close()

# close reporter when you're done
reporter.close()
nc = shared_basepath / sim_settings.output_filename
chk = shared_basepath / sim_settings.checkpoint_storage
else:
# close reporter when you're done, prevent file handle clashes
reporter.close()

nc = shared_basepath / sim_settings.output_filename
chk = shared_basepath / sim_settings.checkpoint_storage
# clean up the reporter file
fns = [shared_basepath / sim_settings.output_filename,
shared_basepath / sim_settings.checkpoint_storage]
for fn in fns:
os.remove(fn)
finally:
energy_context_cache.empty()
sampler_context_cache.empty()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like it should (eventually) get upstreamed, as these context objects are used with a context manager pattern.

Copy link
Member

@IAlibay IAlibay Apr 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that, as a short-term solution, it might be good to at least make sure that the context is cleanly removed in multistatesampler.__del__.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, upstreaming here: choderalab/openmmtools#690

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


if not dry: # pragma: no-cover
return {
'nc': nc,
'last_checkpoint': chk,
'unit_estimate': est,
}
else:
# close reporter when you're done, prevent file handle clashes
reporter.close()

# clean up the reporter file
fns = [shared_basepath / sim_settings.output_filename,
shared_basepath / sim_settings.checkpoint_storage]
for fn in fns:
os.remove(fn)
return {'debug': {'sampler': sampler}}

def _execute(
Expand Down