Skip to content

Commit

Permalink
Merge pull request #67 from Becksteinlab/simplify-cache
Browse files Browse the repository at this point in the history
Simplify cache
  • Loading branch information
ljwoods2 authored Oct 3, 2024
2 parents d57e287 + bda675d commit 6b35db8
Showing 1 changed file with 50 additions and 166 deletions.
216 changes: 50 additions & 166 deletions zarrtraj/ZARR.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ class ZARRH5MDReader(base.ReaderBase):
@due.dcite(
Doi("10.1002/jcc.21787"), description="MDAnalysis 2011", path=__name__
)
@due.dcite(Doi("10.5281/zenodo.3773449"), description="Zarr", path=__name__)
@due.dcite(
Doi("10.5281/zenodo.3773449"), description="Zarr", path=__name__
)
@store_init_arguments
def __init__(
self,
Expand Down Expand Up @@ -313,8 +315,7 @@ def __init__(
# Set to none so close() can be called
self._file = None
self._cache = None
# Read first timestep
self._frame_seq = collections.deque([0])

if not HAS_ZARR:
raise RuntimeError("Please install zarr")
super(ZARRH5MDReader, self).__init__(filename, **kwargs)
Expand Down Expand Up @@ -397,7 +398,7 @@ def __init__(
self._global_steparray,
self._stepmaps,
)
self._cache.update_frame_seq(self._frame_seq)

self._read_next_timestep()

def _set_translated_units(self):
Expand Down Expand Up @@ -628,22 +629,14 @@ def _read_next_timestep(self):

def _read_frame(self, frame):
"""reads data from h5md-formatted file and copies to current timestep"""
# frame seq update case 1: read called from iterator-like context
if not self._frame_seq:
self._frame_seq = None
self._cache.update_frame_seq(self._frame_seq)
raise StopIteration
if frame < 0 or frame >= self.n_frames:
raise IOError("Frame index out of range")

self._frame = self._cache.load_frame()
self._frame = self._cache.load_frame(frame)

if self.convert_units:
self._convert_units()

# frame seq update case 2: read called from __getitem__-like context
if len(self._frame_seq) == 0:
self._frame_seq = None
self._cache.update_frame_seq(self._frame_seq)

return self.ts

def _convert_units(self):
Expand All @@ -667,7 +660,6 @@ def _convert_units(self):

def close(self):
"""close reader"""
self._frame_seq = None
if self._cache is not None:
self._cache.cleanup()
if self._file is not None:
Expand All @@ -687,150 +679,6 @@ def Writer(self, filename, n_atoms=None, **kwargs):
kwargs.setdefault("forces", ("force" in self._elements))
return ZARRMDWriter(filename, n_atoms, **kwargs)

def __getitem__(self, frame):
    """Return the Timestep or frame iterator selected by *frame*.

    If *frame* is an integer, the corresponding frame is read and its
    Timestep is returned. Negative numbers are counted from the end.
    If *frame* is a list or array of indices (or of booleans), or a
    :class:`slice`, an iterator over that part of the trajectory is
    returned.

    Note
    ----
    *frame* is a 0-based frame index.

    Note
    ----
    ZARRH5MDReader overrides this method to get access to the sequence
    of frames the user wants. The sequence is handed to the cache only
    when no sequence is currently pending (``self._frame_seq is None``).

    Raises
    ------
    TypeError
        If *frame* is not an integer, a slice, a list or an array.
    """

    def publish(indices):
        # Register the requested frames with the cache, but never
        # overwrite a sequence that is still pending.
        if self._frame_seq is None:
            self._frame_seq = collections.deque(indices)
            self._cache.update_frame_seq(self._frame_seq)

    # Single frame
    if isinstance(frame, numbers.Integral):
        index = self._apply_limits(frame)
        publish([index])
        return self._read_frame_with_aux(index)

    # Explicit list / array of indices, possibly given as a boolean mask
    if isinstance(frame, (list, np.ndarray)):
        selection = frame
        if len(selection) != 0 and isinstance(
            selection[0], (bool, np.bool_)
        ):
            # Turn the boolean mask into the integer indices it selects
            mask = np.asarray(selection, dtype=bool)
            selection = np.arange(len(self))[mask]
        if isinstance(selection, np.ndarray):
            selection = selection.tolist()
        publish(selection)
        return base.FrameIteratorIndices(self, selection)

    # Slice
    if isinstance(frame, slice):
        start, stop, step = self.check_slice_indices(
            frame.start, frame.stop, frame.step
        )
        publish(range(start, stop, step))
        covers_everything = (
            start == 0 and stop == len(self) and step == 1
        )
        if covers_everything:
            return base.FrameIteratorAll(self)
        return base.FrameIteratorSliced(self, frame)

    raise TypeError(
        "Trajectories must be an indexed using an integer,"
        " slice or list of indices"
    )

def __iter__(self):
    """Prepare the reader for iteration over every frame and return it.

    Note
    ----
    ZARRH5MDReader overrides this method to get access to the sequence
    of frames the user wants: after reopening, the full range of frame
    indices is queued and passed to the cache.
    """
    self._reopen()
    # Queue frames 0 .. n_frames - 1 and let the cache know about them
    every_frame = collections.deque(range(self.n_frames))
    self._frame_seq = every_frame
    self._cache.update_frame_seq(every_frame)
    return self

def next(self):
    """Advance to the next frame and return its Timestep.

    When no frame sequence is pending, the frame following the current
    one is queued; if there is no such frame the reader is rewound and
    iteration stops. Auxiliary readers and transformations are applied
    to the Timestep before it is returned.

    Raises
    ------
    StopIteration
        If there is no next frame, or reading it raises
        :class:`EOFError` or :class:`IOError`. The reader is rewound
        first in both cases.
    """
    if self._frame_seq is None:
        upcoming = self._frame + 1
        if upcoming >= self.n_frames:
            # Ran off the end of the trajectory
            self.rewind()
            raise StopIteration from None
        self._frame_seq = collections.deque([upcoming])
        self._cache.update_frame_seq(self._frame_seq)

    try:
        ts = self._read_next_timestep()
    except (EOFError, IOError):
        self.rewind()
        raise StopIteration from None

    # Only reached when the read succeeded
    for aux_reader in self._auxs.values():
        ts = aux_reader.update_ts(ts)

    return self._apply_transformations(ts)

def iter_as_aux(self, auxname):
    """Iterate over the trajectory with an auxiliary reader.

    Yields the result of ``next_as_aux(auxname)`` until it raises
    :class:`StopIteration`.

    Note
    ----
    ZARRH5MDReader overrides this method to get access to the sequence
    of frames the user wants: the full range of frame indices is queued
    and passed to the cache before the auxiliary reader is restarted.
    """
    aux = self._check_for_aux(auxname)
    self._reopen()

    every_frame = collections.deque(range(self.n_frames))
    self._frame_seq = every_frame
    self._cache.update_frame_seq(every_frame)

    aux._restart()
    while True:
        try:
            yield self.next_as_aux(auxname)
        except StopIteration:
            # Exhausted: end the generator quietly
            return

def copy(self):
    """Return an independent copy of this Reader.

    The new Reader is constructed from the stored constructor arguments
    (``self._kwargs``), so it has its own file handle and can
    seek/iterate independently of the original. The current state of
    the Timestep held in the original Reader is copied as well.

    Note
    ----
    ZARRH5MDReader overrides this method; in addition to seeking the
    new reader through its own ``__getitem__``, it points the new
    reader's cache at the copied Timestep.

    .. versionchanged:: 2.2.0
       Arguments used to construct the reader are correctly captured and
       passed to the creation of the new class. Previously only
       ``n_atoms`` was passed to class copies, leading to a class created
       with default parameters which may differ from the original class.
    """
    duplicate = self.__class__(**self._kwargs)

    if self.transformations:
        duplicate.add_transformations(*self.transformations)

    # Seek the new reader to the frame the original is currently on
    duplicate[self.ts.frame]

    # Overwrite the freshly loaded Timestep with a copy of ours, in case
    # ours has been modified since it was loaded, and keep the cache
    # pointing at that same object
    duplicate.ts = self.ts.copy()
    duplicate._cache._timestep = duplicate.ts

    for aux_name, aux_reader in self._auxs.items():
        duplicate.add_auxiliary(aux_name, aux_reader.copy())

    return duplicate

@property
def n_frames(self):
"""number of frames in trajectory"""
Expand Down Expand Up @@ -870,6 +718,41 @@ def parse_n_atoms(filename, group=None, so=None):
"You must include a topology file."
)

def copy(self):
    """Return an independent copy of this Reader.

    The new Reader is constructed from the stored constructor arguments
    (``self._kwargs``), so it has its own file handle and can
    seek/iterate independently of the original. The current state of
    the Timestep held in the original Reader is copied as well.

    Note
    ----
    ZARRH5MDReader overrides this method to make the new reader's cache
    use the copied Timestep.

    .. versionchanged:: 2.2.0
       Arguments used to construct the reader are correctly captured and
       passed to the creation of the new class. Previously only
       ``n_atoms`` was passed to class copies, leading to a class created
       with default parameters which may differ from the original class.
    """
    duplicate = self.__class__(**self._kwargs)

    if self.transformations:
        duplicate.add_transformations(*self.transformations)

    # Seek the new reader to the frame the original is currently on
    duplicate[self.ts.frame]

    # Overwrite the freshly loaded Timestep with a copy of ours, in case
    # ours has been modified since it was loaded, and keep the cache
    # pointing at that same object
    duplicate.ts = self.ts.copy()
    duplicate._cache._timestep = duplicate.ts

    for aux_name, aux_reader in self._auxs.items():
        duplicate.add_auxiliary(aux_name, aux_reader.copy())

    return duplicate


class H5MDElementBuffer:
def __init__(
Expand Down Expand Up @@ -996,9 +879,9 @@ def flush(self):
if num_v_frames == 0:
num_v_frames = self._val_frames_per_chunk

self._val[self._val_idx - num_v_frames : self._val_idx] = self._val_buf[
:num_v_frames
]
self._val[self._val_idx - num_v_frames : self._val_idx] = (
self._val_buf[:num_v_frames]
)
self._val.resize(self._val_idx, *self._val_chunks[1:])

num_t_frames = self._t_idx % self._t_frames_per_chunk
Expand Down Expand Up @@ -1248,7 +1131,9 @@ def __init__(

protocol = get_protocol(filename)
if protocol not in ZARRTRAJ_NETWORK_PROTOCOLS and protocol != "file":
raise ValueError(f"Unsupported protocol '{protocol}' for Zarrtraj.")
raise ValueError(
f"Unsupported protocol '{protocol}' for Zarrtraj."
)
if protocol in ZARRTRAJ_EXPERIMENTAL_PROTOCOLS:
warnings.warn(
f"Zarrtraj is using the experimental protocol '{protocol}' "
Expand Down Expand Up @@ -1649,9 +1534,8 @@ def update_desired_dsets(
self._global_steparray = global_steparray
self._stepmaps = stepmaps

def load_frame(self):
def load_frame(self, frame):
"""Reader responsible for raising StopIteration when no more frames"""
frame = self._frame_seq.popleft()
self._load_timestep_frame(frame)
return frame

Expand Down

0 comments on commit 6b35db8

Please sign in to comment.