Skip to content

Commit

Permalink
#474 Refactored mem_use.py, changed memory allocation in read.py (float32)
Browse files Browse the repository at this point in the history
  • Loading branch information
N720720 committed Jun 7, 2024
1 parent 697773d commit 60ff6c6
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 25 deletions.
33 changes: 26 additions & 7 deletions lindemann/index/mem_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,34 @@
import numpy.typing as npt


def in_gb(nframes: int, natoms: int) -> str:
    """Estimate the memory needed by each calculation flag, in gigabytes (GB).

    All sizes are computed from the item size of numpy's float32 data type
    and converted with 1 GB = 1024**3 bytes, rounded to four decimals.

    Args:
        nframes (int): The number of frames in the trajectory.
        natoms (int): The number of atoms per frame in the trajectory.

    Returns:
        str: A formatted, multi-line string (starting with a newline)
            containing the memory usage for the different configurations:
            - per_trj: Memory required when the `-t` flag is used.
            - per_frames: Memory required when the `-f` flag is used.
            - per_atoms: Memory required when the `-a` flag is used.
    """
    bytes_per_gb = 1024**3
    float_size = np.float32().nbytes

    # Number of unique atom pairs, i.e. the upper triangle without the diagonal.
    num_distances = natoms * (natoms - 1) // 2

    # Trajectory positions: one (x, y, z) triple per atom and frame.
    trj = nframes * natoms * 3 * float_size
    # Two float32 values are accounted for per unique atom pair.
    pair_buffers = num_distances * 2 * float_size

    # Terms that only enter the per-atom (-a) estimate.
    atom_atom_array = 3 * natoms * natoms * float_size
    atom_array = natoms * float_size
    linde_index = nframes * natoms * float_size

    trj_bytes = trj + pair_buffers
    # The per-frame variant additionally accounts for one float32 per frame.
    frames_bytes = trj + pair_buffers + nframes * float_size
    atoms_bytes = trj + atom_atom_array + atom_array + linde_index

    per_trj = f"\nFlag -t (per_trj) will use {np.round(trj_bytes / bytes_per_gb, 4)} GB\n"
    per_frames = f"Flag -f (per_frames) will use {np.round(frames_bytes / bytes_per_gb, 4)} GB\n"
    per_atoms = f"Flag -a (per_atoms) will use {np.round(atoms_bytes / bytes_per_gb, 4)} GB"
    return f"{per_trj}{per_frames}{per_atoms}"
7 changes: 3 additions & 4 deletions lindemann/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,8 @@ def main(
typer.Exit()

elif timeit and single_process:
# we use float32 here since float64 is not needed for my purposes and it enables us to use nb fastmath. Change to np.float64 if you need more precision.
start = time.time()
linde_for_time = per_trj.calculate(tjr_frames.astype(np.float32))
linde_for_time = per_trj.calculate(tjr_frames)
time_diff = time.time() - start

console.print(
Expand All @@ -170,8 +169,8 @@ def main(
typer.Exit()

elif mem_useage and single_process:

mem_use_in_gb = mem_use.in_gb(tjr_frames)
nframes, natoms, _ = tjr_frames.shape
mem_use_in_gb = mem_use.in_gb(nframes, natoms)

console.print(f"[magenta]memory use:[/] [bold blue]{mem_use_in_gb}[/]")
typer.Exit()
Expand Down
39 changes: 26 additions & 13 deletions lindemann/trajectory/read.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
from typing import Optional

import os

import numpy as np
import numpy.typing as npt
from ovito.io import import_file
from ovito.modifiers import SelectTypeModifier


def frames(trjfile: str, nframes: Optional[int] = None) -> npt.NDArray[np.float64]:
"""
Get the frames from the lammps trajectory using ovito pipeline and import_file function.
It returns frames and the number of frames to use for calculating the Lindemann Index.
def frames(trjfile: str, nframes: Optional[int] = None) -> npt.NDArray[np.float32]:
"""
Extracts the frame position data from a MD trajectory file using the OVITO pipeline.
The function loads the specified trajectory file, applies a selection modifier to filter
particles of type 1, 2, and 3, and computes the positions for a specified number of frames.
If `nframes` is None, the function will attempt to process all frames in the trajectory.
Parameters:
trjfile (str): Path to the trajectory file to be processed.
nframes (Optional[int]): The number of frames to process. If not specified, all frames
in the trajectory file are processed. If the specified number
exceeds the available frames in the file, a ValueError is raised.
if not os.path.exists(trjfile):
raise RuntimeError(f"Error: file {trjfile} not found!")
Returns:
npt.NDArray[np.float32]: A 3D NumPy array of shape (nframes, num_particles, 3) containing
the position data for each particle across the specified frames.
Raises:
ValueError: If `nframes` is more than the number of available frames in the trajectory file.
Example:
>>> positions = frames("path/to/trajectory.lammpstrj", 100)
This would load 100 frames from the specified file and return the position data.
"""

pipeline = import_file(trjfile, sort_particles=True)
num_frame = pipeline.source.num_frames
Expand All @@ -25,14 +40,12 @@ def frames(trjfile: str, nframes: Optional[int] = None) -> npt.NDArray[np.float6
data = pipeline.compute()
num_particle = data.particles.count

# If no argument is given use all frames
if nframes is None:
nframes = num_frame
# make sure nobody puts more frames then exists
assert num_frame >= nframes
elif nframes > num_frame:
raise ValueError(f"Requested {nframes} frames, but only {num_frame} frames are available.")

# initialise array, could be problematic for big clusters and a lot of frames
position = np.zeros((nframes, num_particle, 3))
position = np.zeros((nframes, num_particle, 3), dtype=np.float32)

for frame in range(nframes):
data = pipeline.compute(frame)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_example/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_t_flag():

def test_m_flag():
    """Check the memory-use report printed for the ``-m`` flag.

    The expected text lists one estimate per calculation flag
    (``-t``, ``-f`` and ``-a``) for the bundled example trajectory.
    """
    flag = "-m"
    res_str = "memory use: \nFlag -t (per_trj) will use 0.0034 GB\nFlag -f (per_frames) will use 0.0034 GB\nFlag -a (per_atoms) will use 0.0058 GB\n"
    trajectory = ["tests/test_example/459_02.lammpstrj"]
    # NOTE(review): helper is defined elsewhere in this test module; presumably
    # it runs the CLI in single- and multi-process mode — confirm.
    single_process_and_multiprocess(trajectory, flag, res_str)

Expand Down

0 comments on commit 60ff6c6

Please sign in to comment.