From 308f46950ae3c03cb18e1bfc715c7bb4d26e195f Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Mon, 18 Mar 2024 20:25:40 +0000 Subject: [PATCH] Handle trajectory files during restarts. --- src/somd2/runner/_dynamics.py | 119 +++++++++++++++++++++++++++------- tests/runner/test_restart.py | 17 ++++- 2 files changed, 109 insertions(+), 27 deletions(-) diff --git a/src/somd2/runner/_dynamics.py b/src/somd2/runner/_dynamics.py index a49f7e6..29a93f0 100644 --- a/src/somd2/runner/_dynamics.py +++ b/src/somd2/runner/_dynamics.py @@ -91,9 +91,17 @@ def __init__( raise TypeError("config must be a Config object") self._config = config - # If resarting, subtract the time already run from the total runtime + # If restarting, subtract the time already run from the total runtime if self._config.restart: self._config.runtime = str(self._config.runtime - self._system.time()) + + # Work out the current block number. + self._current_block = 1 + int( + self._system.time().value() / self._config.checkpoint_frequency.value() + ) + else: + self._current_block = 0 + self._lambda_val = lambda_val self._lambda_array = lambda_array self._increment = increment @@ -127,14 +135,14 @@ def increment_filename(base_filename, suffix): raise ValueError("lambda_value not in lambda_array") filenames = {} index = lambda_array.index(lambda_value) - filenames["topology"] = f"system.prm7" + filenames["topology"] = "system.prm7" filenames["checkpoint"] = f"checkpoint_{index}.s3" filenames["energy_traj"] = f"energy_traj_{index}.parquet" + filenames["trajectory"] = f"traj_{index}.dcd" + filenames["trajectory_chunk"] = f"traj_{index}_" if restart: - filenames["trajectory"] = increment_filename(f"traj_{index}", "dcd") filenames["config"] = increment_filename("config", "yaml") else: - filenames["trajectory"] = f"traj_{index}.dcd" filenames["config"] = "config.yaml" return filenames @@ -352,12 +360,11 @@ def generate_lam_vals(lambda_base, increment): _logger.info(f"Running dynamics at {_lam_sym} = {self._lambda_val}") - # Create a list to hold the trajectory chunks. Add the topology file to it. - traj_files = [topology] - if self._config.checkpoint_frequency.value() > 0.0: # Calculate the number of blocks and the remaineder time. - frac = self._config.runtime.value() / self._config.checkpoint_frequency.value() + frac = ( + self._config.runtime.value() / self._config.checkpoint_frequency.value() + ) num_blocks = int(frac) rem = frac - num_blocks @@ -368,8 +375,13 @@ def generate_lam_vals(lambda_base, increment): sire_checkpoint_name = str( _Path(self._config.output_directory) / self._filenames["checkpoint"] ) + # Run num_blocks dynamics and then run a final block if rem > 0 for x in range(int(num_blocks)): + # Add the current block number. + x += self._current_block + + # Run the dynamics. try: self._dyn.run( self._config.checkpoint_frequency, @@ -381,21 +393,29 @@ def generate_lam_vals(lambda_base, increment): ) except: raise + + # Checkpoint. try: + # Set to the current block number if this is a restart. + if x == 0: + x = self._current_block + # Commit the current system and save it to a checkpoint file. self._system = self._dyn.commit() sr.stream.save(self._system, str(sire_checkpoint_name)) # Save the current trajectory chunk to file. if self._config.save_trajectories: - traj_filename = str( - self._config.output_directory - / self._filenames["trajectory"] - ).replace(".dcd", f"_{x}.dcd") + traj_filename = ( + str( + self._config.output_directory + / self._filenames["trajectory_chunk"] + ) + + f"{x}.dcd" + ) sr.save( self._system.trajectory(), traj_filename, format=["DCD"] ) - traj_files.append(traj_filename) # Delete the trajectory from memory. self._system.delete_all_frames() @@ -404,9 +424,9 @@ def generate_lam_vals(lambda_base, increment): df = self._system.energy_trajectory( to_alchemlyb=True, energy_unit="kT" ) - if x == 0: + if x == self._current_block: # Not including speed in checkpoints for now. - f = _dataframe_to_parquet( + parquet = _dataframe_to_parquet( df, metadata={ "attrs": df.attrs, @@ -427,16 +447,18 @@ def generate_lam_vals(lambda_base, increment): self._system.set_property("lambda", self._lambda_val) else: _parquet_append( - f, + parquet, df.iloc[-int(energy_per_block) :], ) _logger.info( - f"Finished block {x+1} of {num_blocks + int(rem > 0)} for {_lam_sym} = {self._lambda_val}" + f"Finished block {x+1} of {self._current_block + num_blocks + int(rem > 0)} " + f"for {_lam_sym} = {self._lambda_val}" ) except: raise # No need to checkpoint here as it is the final block. if rem > 0: + x += 1 try: self._dyn.run( rem, @@ -447,7 +469,26 @@ def generate_lam_vals(lambda_base, increment): auto_fix_minimise=False, ) - f"Finished block {x+1} of {num_blocks + int(rem > 0)} for {_lam_sym} = {self._lambda_val}" + # Save the current trajectory chunk to file. + if self._config.save_trajectories: + traj_filename = ( + str( + self._config.output_directory + / self._filenames["trajectory_chunk"] + ) + + f"{x}.dcd" + ) + sr.save( + self._system.trajectory(), traj_filename, format=["DCD"] + ) + + # Delete the trajectory from memory. + self._system.delete_all_frames() + + _logger.info( + f"Finished block {x+1} of {self._current_block + num_blocks + int(rem > 0)} " + f"for {_lam_sym} = {self._lambda_val}" + ) except: raise self._system = self._dyn.commit() @@ -467,12 +508,42 @@ def generate_lam_vals(lambda_base, increment): # Assemble and save the final energy trajectory. if self._config.save_trajectories: - # Load the chunked trajectory files and concatenate them. - system = sr.load(traj_files) + # Create the final trajectory file name. + traj_filename = str( + self._config.output_directory / self._filenames["trajectory"] + ) + + # Glob for the trajectory chunks. + from glob import glob + + traj_chunks = sorted( + glob( + str( + self._config.output_directory + / f"{self._filenames['trajectory_chunk']}*" + ) + ) + ) + + # If this is a restart, then we need to check for an existing + # trajectory file with the same name. If it exists and is non-empty, + # then copy it to a backup file and prepend it to the list of chunks. + if self._config.restart: + path = _Path(traj_filename) + if path.exists() and path.stat().st_size > 0: + from shutil import copyfile + + copyfile(traj_filename, f"{traj_filename}.bak") + traj_chunks = [f"{traj_filename}.bak"] + traj_chunks + + # Load the topology and chunked trajectory files. + system = sr.load([topology] + traj_chunks) + + # Save the final trajectory to a single file. traj_filename = str( self._config.output_directory / self._filenames["trajectory"] ) - sr.save(self._system.trajectory(), traj_filename, format=["DCD"]) + sr.save(system.trajectory(), traj_filename, format=["DCD"]) # Delete the trajectory from memory. self._system.delete_all_frames() @@ -482,8 +553,9 @@ def generate_lam_vals(lambda_base, increment): del system # Now remove the chunked trajectory files. - for f in traj_files: - _Path(f).unlink() + for chunk in traj_chunks: + if chunk != f"{traj_filename}.bak": + _Path(chunk).unlink() # Add config and lambda value to the system properties. self._system.add_shared_property( @@ -493,7 +565,6 @@ def generate_lam_vals(lambda_base, increment): # Save the final system to checkpoint file. sr.stream.save(self._system, sire_checkpoint_name) - _logger.debug(f"Properties on system: {self._system.property_keys()}") df = self._system.energy_trajectory(to_alchemlyb=True, energy_unit="kT") return df diff --git a/tests/runner/test_restart.py b/tests/runner/test_restart.py index d633dc9..b42e29a 100644 --- a/tests/runner/test_restart.py +++ b/tests/runner/test_restart.py @@ -44,6 +44,11 @@ def test_restart(mols, request): num_entries = len(energy_traj_1.index) + # Load the trajectory. + traj_1 = sr.load( + [str(Path(tmpdir) / "system.prm7"), str(Path(tmpdir) / "traj_0.dcd")] + ) + # Check that both config and lambda have been written # as properties to the streamed checkpoint file. checkpoint = sr.stream.load(str(Path(tmpdir) / "checkpoint_0.s3")) @@ -79,11 +84,17 @@ def test_restart(mols, request): # Check that first half of energy trajectory is the same assert energy_traj_1.equals(energy_traj_2.iloc[:num_entries]) + # Check that second energy trajectory is twice as long as the first assert len(energy_traj_2.index) == 2 * num_entries - # Check that a second trajectory was written and that the first still exists - assert Path.exists(Path(tmpdir) / "traj_0.dcd") - assert Path.exists(Path(tmpdir) / "traj_0_1.dcd") + + # Reload the trajectory. + traj_2 = sr.load( + [str(Path(tmpdir) / "system.prm7"), str(Path(tmpdir) / "traj_0.dcd")] + ) + + # Check that the trajectory is twice as long as the first. + assert traj_2.num_frames() == 2 * traj_1.num_frames() config_difftimestep = config_new.copy() config_difftimestep["runtime"] = "36fs"