Skip to content

Commit

Permalink
Handle trajectory files during restarts.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Mar 19, 2024
1 parent fa16960 commit 308f469
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 27 deletions.
119 changes: 95 additions & 24 deletions src/somd2/runner/_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,17 @@ def __init__(
raise TypeError("config must be a Config object")

self._config = config
# If resarting, subtract the time already run from the total runtime
# If restarting, subtract the time already run from the total runtime
if self._config.restart:
self._config.runtime = str(self._config.runtime - self._system.time())

# Work out the current block number.
self._current_block = 1 + int(
self._system.time().value() / self._config.checkpoint_frequency.value()
)
else:
self._current_block = 0

self._lambda_val = lambda_val
self._lambda_array = lambda_array
self._increment = increment
Expand Down Expand Up @@ -127,14 +135,14 @@ def increment_filename(base_filename, suffix):
raise ValueError("lambda_value not in lambda_array")
filenames = {}
index = lambda_array.index(lambda_value)
filenames["topology"] = f"system.prm7"
filenames["topology"] = "system.prm7"
filenames["checkpoint"] = f"checkpoint_{index}.s3"
filenames["energy_traj"] = f"energy_traj_{index}.parquet"
filenames["trajectory"] = f"traj_{index}.dcd"
filenames["trajectory_chunk"] = f"traj_{index}_"
if restart:
filenames["trajectory"] = increment_filename(f"traj_{index}", "dcd")
filenames["config"] = increment_filename("config", "yaml")
else:
filenames["trajectory"] = f"traj_{index}.dcd"
filenames["config"] = "config.yaml"
return filenames

Expand Down Expand Up @@ -352,12 +360,11 @@ def generate_lam_vals(lambda_base, increment):

_logger.info(f"Running dynamics at {_lam_sym} = {self._lambda_val}")

# Create a list to hold the trajectory chunks. Add the topology file to it.
traj_files = [topology]

if self._config.checkpoint_frequency.value() > 0.0:
# Calculate the number of blocks and the remaineder time.
frac = self._config.runtime.value() / self._config.checkpoint_frequency.value()
frac = (
self._config.runtime.value() / self._config.checkpoint_frequency.value()
)
num_blocks = int(frac)
rem = frac - num_blocks

Expand All @@ -368,8 +375,13 @@ def generate_lam_vals(lambda_base, increment):
sire_checkpoint_name = str(
_Path(self._config.output_directory) / self._filenames["checkpoint"]
)

# Run num_blocks dynamics and then run a final block if rem > 0
for x in range(int(num_blocks)):
# Add the current block number.
x += self._current_block

# Run the dynamics.
try:
self._dyn.run(
self._config.checkpoint_frequency,
Expand All @@ -381,21 +393,29 @@ def generate_lam_vals(lambda_base, increment):
)
except:
raise

# Checkpoint.
try:
# Set to the current block number if this is a restart.
if x == 0:
x = self._current_block

# Commit the current system and save it to a checkpoint file.
self._system = self._dyn.commit()
sr.stream.save(self._system, str(sire_checkpoint_name))

# Save the current trajectory chunk to file.
if self._config.save_trajectories:
traj_filename = str(
self._config.output_directory
/ self._filenames["trajectory"]
).replace(".dcd", f"_{x}.dcd")
traj_filename = (
str(
self._config.output_directory
/ self._filenames["trajectory_chunk"]
)
+ f"{x}.dcd"
)
sr.save(
self._system.trajectory(), traj_filename, format=["DCD"]
)
traj_files.append(traj_filename)

# Delete the trajectory from memory.
self._system.delete_all_frames()
Expand All @@ -404,9 +424,9 @@ def generate_lam_vals(lambda_base, increment):
df = self._system.energy_trajectory(
to_alchemlyb=True, energy_unit="kT"
)
if x == 0:
if x == self._current_block:
# Not including speed in checkpoints for now.
f = _dataframe_to_parquet(
parquet = _dataframe_to_parquet(
df,
metadata={
"attrs": df.attrs,
Expand All @@ -427,16 +447,18 @@ def generate_lam_vals(lambda_base, increment):
self._system.set_property("lambda", self._lambda_val)
else:
_parquet_append(
f,
parquet,
df.iloc[-int(energy_per_block) :],
)
_logger.info(
f"Finished block {x+1} of {num_blocks + int(rem > 0)} for {_lam_sym} = {self._lambda_val}"
f"Finished block {x+1} of {self._current_block + num_blocks + int(rem > 0)} "
f"for {_lam_sym} = {self._lambda_val}"
)
except:
raise
# No need to checkpoint here as it is the final block.
if rem > 0:
x += 1
try:
self._dyn.run(
rem,
Expand All @@ -447,7 +469,26 @@ def generate_lam_vals(lambda_base, increment):
auto_fix_minimise=False,
)

f"Finished block {x+1} of {num_blocks + int(rem > 0)} for {_lam_sym} = {self._lambda_val}"
# Save the current trajectory chunk to file.
if self._config.save_trajectories:
traj_filename = (
str(
self._config.output_directory
/ self._filenames["trajectory_chunk"]
)
+ f"{x}.dcd"
)
sr.save(
self._system.trajectory(), traj_filename, format=["DCD"]
)

# Delete the trajectory from memory.
self._system.delete_all_frames()

_logger.info(
f"Finished block {x+1} of {self._current_block + num_blocks + int(rem > 0)} "
f"for {_lam_sym} = {self._lambda_val}"
)
except:
raise
self._system = self._dyn.commit()
Expand All @@ -467,12 +508,42 @@ def generate_lam_vals(lambda_base, increment):

# Assemble and save the final energy trajectory.
if self._config.save_trajectories:
# Load the chunked trajectory files and concatenate them.
system = sr.load(traj_files)
# Create the final trajectory file name.
traj_filename = str(
self._config.output_directory / self._filenames["trajectory"]
)

# Glob for the trajectory chunks.
from glob import glob

traj_chunks = sorted(
glob(
str(
self._config.output_directory
/ f"{self._filenames['trajectory_chunk']}*"
)
)
)

# If this is a restart, then we need to check for an existing
# trajectory file with the same name. If it exists and is non-empty,
# then copy it to a backup file and prepend it to the list of chunks.
if self._config.restart:
path = _Path(traj_filename)
if path.exists() and path.stat().st_size > 0:
from shutil import copyfile

copyfile(traj_filename, f"{traj_filename}.bak")
traj_chunks = [f"{traj_filename}.bak"] + traj_chunks

# Load the topology and chunked trajectory files.
system = sr.load([topology] + traj_chunks)

# Save the final trajectory to a single file.
traj_filename = str(
self._config.output_directory / self._filenames["trajectory"]
)
sr.save(self._system.trajectory(), traj_filename, format=["DCD"])
sr.save(system.trajectory(), traj_filename, format=["DCD"])

# Delete the trajectory from memory.
self._system.delete_all_frames()
Expand All @@ -482,8 +553,9 @@ def generate_lam_vals(lambda_base, increment):
del system

# Now remove the chunked trajectory files.
for f in traj_files:
_Path(f).unlink()
for chunk in traj_chunks:
if chunk != f"{traj_filename}.bak":
_Path(chunk).unlink()

# Add config and lambda value to the system properties.
self._system.add_shared_property(
Expand All @@ -493,7 +565,6 @@ def generate_lam_vals(lambda_base, increment):

# Save the final system to checkpoint file.
sr.stream.save(self._system, sire_checkpoint_name)
_logger.debug(f"Properties on system: {self._system.property_keys()}")
df = self._system.energy_trajectory(to_alchemlyb=True, energy_unit="kT")
return df

Expand Down
17 changes: 14 additions & 3 deletions tests/runner/test_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def test_restart(mols, request):

num_entries = len(energy_traj_1.index)

# Load the trajectory.
traj_1 = sr.load(
[str(Path(tmpdir) / "system.prm7"), str(Path(tmpdir) / "traj_0.dcd")]
)

# Check that both config and lambda have been written
# as properties to the streamed checkpoint file.
checkpoint = sr.stream.load(str(Path(tmpdir) / "checkpoint_0.s3"))
Expand Down Expand Up @@ -79,11 +84,17 @@ def test_restart(mols, request):

# Check that first half of energy trajectory is the same
assert energy_traj_1.equals(energy_traj_2.iloc[:num_entries])

# Check that second energy trajectory is twice as long as the first
assert len(energy_traj_2.index) == 2 * num_entries
# Check that a second trajectory was written and that the first still exists
assert Path.exists(Path(tmpdir) / "traj_0.dcd")
assert Path.exists(Path(tmpdir) / "traj_0_1.dcd")

# Reload the trajectory.
traj_2 = sr.load(
[str(Path(tmpdir) / "system.prm7"), str(Path(tmpdir) / "traj_0.dcd")]
)

# Check that the trajectory is twice as long as the first.
assert traj_2.num_frames() == 2 * traj_1.num_frames()

config_difftimestep = config_new.copy()
config_difftimestep["runtime"] = "36fs"
Expand Down

0 comments on commit 308f469

Please sign in to comment.