Skip to content

Commit

Permalink
Merge branch 'main' into pdb_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
BradyAJohnston authored Nov 26, 2024
2 parents 81d6e80 + a1fd75f commit c7d0b42
Show file tree
Hide file tree
Showing 7 changed files with 849 additions and 133 deletions.
6 changes: 3 additions & 3 deletions molecularnodes/bpyd/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,14 +428,14 @@ def list_attributes(
A list of attribute names if the molecule object exists, None otherwise.
"""
if evaluate:
strings = list(self.evaluate().attributes.keys())
strings = list(self.evaluate().object.data.attributes.keys())
else:
strings = list(self.object.attributes.keys())
strings = list(self.object.data.attributes.keys())

if not drop_hidden:
return strings
else:
return filter(lambda x: not x.startswith("."), strings)
return [x for x in strings if not x.startswith(".")]

def __len__(self) -> int:
"""
Expand Down
274 changes: 178 additions & 96 deletions molecularnodes/entities/trajectory/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
import MDAnalysis as mda
import numpy as np
import numpy.typing as npt
from math import floor, remainder

from ... import data
from ..entity import MolecularEntity
from ...blender import coll, nodes, path_resolve
from ... import bpyd
from ...utils import correct_periodic_positions
from ...utils import (
correct_periodic_positions,
frame_mapper,
frames_to_average,
fraction,
)
from .selections import Selection, TrajectorySelectionItem


Expand All @@ -21,6 +27,7 @@ def __init__(self, universe: mda.Universe, world_scale: float = 0.01):
self.calculations: Dict[str, Callable] = {}
self.world_scale = world_scale
self.frame_mapping: npt.NDArray[np.in64] | None = None
self.cache: dict = {}
bpy.context.scene.MNSession.trajectories[self.uuid] = self

def selection_from_ui(self, ui_item: TrajectorySelectionItem) -> Selection:
Expand Down Expand Up @@ -79,45 +86,6 @@ def apply_selection(self, selection: Selection):
"Set the boolean attribute for this selection on the mesh of the object"
self.set_boolean(selection.to_mask(), name=selection.name)

@property
def subframes(self):
obj = self.object
if obj is None:
return None
return obj.mn.subframes

@subframes.setter
def subframes(self, value: int):
obj = self.object
if obj is None:
return None
obj.mn.subframes = value

@property
def offset(self) -> int:
try:
return self.object.mn.offset
except AttributeError:
return None

@offset.setter
def offset(self, value: int):
self.object.mn.offset = value

@property
def interpolate(self) -> bool:
obj = self.object
if obj is None:
return None
return obj.mn.interpolate

@interpolate.setter
def interpolate(self, value: bool):
obj = self.object
if obj is None:
return None
obj.mn.interpolate = value

@property
def is_orthorhombic(self):
dim = self.universe.dimensions
Expand Down Expand Up @@ -202,18 +170,40 @@ def mass(self) -> np.ndarray:
]
return np.array(masses)

@property
def n_frames(self) -> int:
return self.universe.trajectory.n_frames

@property
def res_id(self) -> np.ndarray:
return self.atoms.resnums

@property
def frame(self) -> int:
def uframe(self) -> int:
"""
Get the current frame number of the linked `Universe.trajectory`.
Returns:
int: Current frame number in the trajectory.
"""
return self.universe.trajectory.frame

@frame.setter
def frame(self, value) -> None:
@uframe.setter
def uframe(self, value) -> None:
"""
Set the current frame number of the linked `Universe.trajectory`.
The frame number is clamped between 0 and n_frames-1 to prevent
out-of-bounds access.
Args:
value (int): Target frame number to set.
Returns:
None
"""
if self.universe.trajectory.frame != value:
self.universe.trajectory[value]
self.universe.trajectory[max(min(value, self.n_frames - 1), 0)]

@property
def res_name(self) -> np.ndarray:
Expand Down Expand Up @@ -516,71 +506,163 @@ def _update_selections(self):
except Exception as e:
print(e)

def _update_positions(self, frame):
"""
The function that will be called when the frame changes.
It will update the positions and selections of the atoms in the scene.
"""
universe = self.universe
frame_mapping = self.frame_mapping
obj = self.object
@property
def subframes(self) -> int:
return self.object.mn.subframes

subframes: int = obj.mn.subframes
interpolate: bool = obj.mn.interpolate
offset: int = obj.mn.offset

# we subtraect the offset, a negative offset value ensures that the trajectory starts
# playback that many frames before 0 and a positive value ensures we start the
# playback after 0
frame -= offset
# for actually getting frames from the trajectory we need to clamp it to a lower
# bound of 0 which will be the start frame for the trajectory
frame = max(frame, 0)

if frame_mapping:
# add the subframes to the frame mapping
frame_map = np.repeat(frame_mapping, subframes + 1)
# get the current and next frames
frame_a = frame_map[frame]
frame_b = frame_map[frame + 1]
@subframes.setter
def subframes(self, value: int) -> None:
try:
self.object.mn.subframes = value
except AttributeError:
raise bpyd.object.ObjectMissingError(
"Trajectory does not have a linked object. Cannot get subframes related to this object."
)

else:
# get the initial frame
if subframes == 0:
frame_a = frame
else:
frame_a = int(frame / (subframes + 1))
@property
def offset(self) -> int:
return self.object.mn.offset

# get the next frame
frame_b = frame_a + 1
@offset.setter
def offset(self, value: int) -> None:
self.object.mn.offset = value

if frame_a >= universe.trajectory.n_frames:
return None
@property
def average(self) -> int:
return self.object.mn.average

# set the trajectory at frame_a
self.frame = frame_a
@average.setter
def average(self, value: int) -> None:
self.object.mn.average = value

if subframes > 0 and interpolate:
fraction = frame % (subframes + 1) / (subframes + 1)
@property
def correct_periodic(self) -> bool:
return self.object.mn.correct_periodic

# get the positions for the next frame
positions_a = self.univ_positions
@correct_periodic.setter
def correct_periodic(self, value: bool) -> None:
self.object.mn.correct_periodic = value

if frame_b < universe.trajectory.n_frames:
self.frame = frame_b
positions_b = self.univ_positions
@property
def interpolate(self) -> bool:
return self.object.mn.interpolate

if obj.mn.correct_periodic and self.is_orthorhombic:
positions_b = correct_periodic_positions(
positions_a,
positions_b,
dimensions=universe.dimensions[:3] * self.world_scale,
@interpolate.setter
def interpolate(self, value: bool) -> None:
self.object.mn.interpolate = value

def _frame_range(self, frame: int):
"Get the trajectory frame numbers over which we will average values"
return frames_to_average(frame, self.average)

def _cache_ordered(self) -> np.ndarray:
"Return the cached frames as a 3D array, in chronological order"
keys = list(self.cache.keys())
keys.sort()
return np.array([self.cache[k] for k in keys])

def adjust_periodic_positions(
self, pos1: np.ndarray, pos2: np.ndarray
) -> np.ndarray:
"Returns the input pos2 with a periodic correction potentially applied"
if self.correct_periodic and self.is_orthorhombic:
return correct_periodic_positions(pos1, pos2, self.universe.dimensions[:3])
else:
return pos2

def position_cache_mean(self, frame: int) -> np.ndarray:
"Return the mean position from the currently cached positions"
self.update_position_cache(frame)

if self.average == 0:
return self.cache[frame]

array = self._cache_ordered()
if self.correct_periodic and self.is_orthorhombic:
# we want to correct the periodic boundary crossing in refernce to the fist
# frame we are averaging
for i, pos in enumerate(array):
if i == 0:
continue
array[i] = self.adjust_periodic_positions(array[0], pos)

return np.mean(array, axis=0)

def _position_at_frame(self, frame: int) -> np.ndarray:
"Return the atom positions at the given universe frame number"
self.uframe = frame
return self.univ_positions

def update_position_cache(self, frame: int, cache_ahead: bool = True) -> None:
"Update the currently cached positions, based on the new frame"
# get the individual frame numbers that we will be caching
frames_to_cache = self._frame_range(frame)

# if we should be looking ahead by 1 for interpolating, ensure we are caching 1
# frame ahead so when the frame changes we already have it stored and aren't
# double dipping
if len(frames_to_cache) == 1 and cache_ahead:
frames_to_cache = np.array(
(frames_to_cache[0], frames_to_cache[0] + 1), dtype=int
)

# only cleanup the cache if we have more than 2 frame stored, helps when moving
# forward or back a single frame
if len(self.cache) > 2:
# remove any frames that no longer need to be cached
to_remove = [f for f in self.cache if f not in frames_to_cache]
for f in to_remove:
del self.cache[f]

# update the cache with any frames that are not yet cached
for f in frames_to_cache:
if f not in self.cache:
self.cache[f] = self._position_at_frame(f)

def frame_mapper(self, frame: int):
return frame_mapper(
frame=frame,
subframes=self.subframes,
offset=self.offset,
mapping=self.frame_mapping,
)

def _update_positions(self, frame):
"""
The function that will be called when the frame changes.
It will update the positions and selections of the atoms in the scene.
"""
# get the two frames of the trajectory to potentially access data from
# uframe_current, uframe_next = [self.frame_mapper(x) for x in (frame, frame + 1)]
uframe_current = self.frame_mapper(frame)
uframe_next = uframe_current + 1

if self.subframes > 0 and self.interpolate:
# if we are adding subframes and interpolating, then we get the positions
# at the two universe frames, then interpolate between them, potentially
# correcting for any periodic boundary crossing
pos_current = self.position_cache_mean(
uframe_current,
)
pos_next = self.position_cache_mean(uframe_next)

# if we are averaging, then we have already applied periodic correction
# and we can skip this step
if self.correct_periodic and self.is_orthorhombic and self.average == 0:
pos_next = correct_periodic_positions(
pos_current,
pos_next,
dimensions=self.universe.dimensions[:3] * self.world_scale,
)

# interpolate between the two sets of positions
self.position = bpyd.lerp(positions_a, positions_b, t=fraction)
self.position = bpyd.lerp(
pos_current, pos_next, t=fraction(frame, self.subframes + 1)
)
else:
self.position = self.univ_positions
# otherwise just get the current positions for the relevant frame and set
# those on the object
self.position = self.position_cache_mean(uframe_current)

def __repr__(self):
return f"<Trajectory, `universe`: {self.universe}, `object`: {self.object}"
13 changes: 11 additions & 2 deletions molecularnodes/props.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class MolecularNodesObjectProperties(bpy.types.PropertyGroup):
description="Number of subframes to insert between frames of the loaded trajectory",
default=0,
update=_update_trajectories,
min=0,
)
offset: IntProperty( # type: ignore
name="Offset",
Expand All @@ -88,10 +89,18 @@ class MolecularNodesObjectProperties(bpy.types.PropertyGroup):
default=True,
update=_update_trajectories,
)
average: IntProperty( # type: ignore
name="Average",
description="Average the position this number of frames either side of the current frame",
default=0,
update=_update_trajectories,
min=0,
soft_max=5,
)
correct_periodic: BoolProperty( # type: ignore
name="Correct",
description="Correct for periodic boundary crossing when using interpolation. Assumes cubic dimensions",
default=True,
description="Correct for periodic boundary crossing when using interpolation or averaging. Assumes cubic dimensions and only works if the unit cell is orthorhombic",
default=False,
update=_update_trajectories,
)
filepath_trajectory: StringProperty( # type: ignore
Expand Down
Loading

0 comments on commit c7d0b42

Please sign in to comment.