Skip to content

Commit

Permalink
Add missing save_energy_components method.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Oct 30, 2024
1 parent 7ee642c commit 19ad761
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
49 changes: 49 additions & 0 deletions src/somd2/runner/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ def __init__(self, system, config):
self._config.checkpoint_frequency / self._config.energy_frequency
)

# Zero the energy sample.
self._nrg_sample = 0

# Create the default dynamics kwargs dictionary. These can be overloaded
# as needed.
self._dynamics_kwargs = {
Expand Down Expand Up @@ -1159,3 +1162,49 @@ def _checkpoint(
self._parquet,
df.iloc[-self._energy_per_block :],
)

def _save_energy_components(self, context):
"""
Internal function to save the energy components for each force group to file.
Parameters
----------
context : :class: `Context <openmm.Context>`
The current OpenMM context.
"""

from copy import deepcopy
import openmm

# Get the current context and system.
system = deepcopy(context.getSystem())

# Add each force to a unique group.
for i, f in enumerate(system.getForces()):
f.setForceGroup(i)

# Create a new context.
new_context = openmm.Context(system, deepcopy(context.getIntegrator()))
new_context.setPositions(context.getState(getPositions=True).getPositions())

header = f"{'# Sample':>10}"
record = f"{self._nrg_sample:>10}"

# Process the records.
for i, f in enumerate(system.getForces()):
state = new_context.getState(getEnergy=True, groups={i})
header += f"{f.getName():>25}"
record += f"{state.getPotentialEnergy().value_in_unit(openmm.unit.kilocalories_per_mole):>25.2f}"

# Write to file.
if self._nrg_sample == 0:
with open(self._filenames["energy_components"], "w") as f:
f.write(header + "\n")
f.write(record + "\n")
else:
with open(self._filenames["energy_components"], "a") as f:
f.write(record + "\n")

# Increment the sample number.
self._nrg_sample += 1
4 changes: 2 additions & 2 deletions src/somd2/runner/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def generate_lam_vals(lambda_base, increment=0.001):
try:
# Save the energy contribution for each force.
if self._config.save_energy_components:
self._save_energy_components()
self._save_energy_components(dynamics.context())

# Commit the current system.
system = dynamics.commit()
Expand Down Expand Up @@ -539,7 +539,7 @@ def generate_lam_vals(lambda_base, increment=0.001):

# Save the energy contribution for each force.
if self._config.save_energy_components:
self._save_energy_components()
self._save_energy_components(dynamics.context())

# Commit the current system.
system = dynamics.commit()
Expand Down

0 comments on commit 19ad761

Please sign in to comment.