Skip to content

Commit

Permalink
Merge pull request #16 from maurergroup/parameter_enhancements
Browse files Browse the repository at this point in the history
- Added tests for parameters and base_parser modules 
- Enhancements to parameters and base_parser 
- Minor fixes elsewhere
  • Loading branch information
dylanbmorgan authored Nov 22, 2024
2 parents 39bfaa0 + 48cd901 commit 692b0b1
Show file tree
Hide file tree
Showing 27 changed files with 1,813 additions and 306 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pypi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ jobs:
environment:
name: pypi
# OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
# url: https://pypi.org/p/YOURPROJECT
url: https://pypi.org/project/dfttoolkit/

steps:
- name: Retrieve release distributions
Expand Down
38 changes: 38 additions & 0 deletions dfttoolkit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# from .benchmarking import BenchmarkAims
# from .custom_ase import CustomAims
# from .friction import FrictionTensor
# from .geometry import AimsGeometry, VaspGeometry, XSFGeometry, XYZGeometry
# from .output import AimsOutput, ELSIOutput
# from .parameters import AimsControl
# from .trajectory import MDTrajectory
# from .utils import math_utils, units
# from .utils.file_utils import aims_bin_path_prompt, check_required_files
# from .utils.geometry_utils import read_xyz_animation
# from .utils.periodic_table import PeriodicTable
# from .utils.run_utils import no_repeat
# from .vibrations import AimsVibrations, VaspVibrations
# from .visualise import VisualiseAims

# __all__ = [
# "BenchmarkAims",
# "CustomAims",
# "FrictionTensor",
# "AimsGeometry",
# "VaspGeometry",
# "XSFGeometry",
# "XYZGeometry",
# "AimsOutput",
# "ELSIOutput",
# "AimsControl",
# "MDTrajectory",
# "math_utils",
# "units",
# "aims_bin_path_prompt",
# "check_required_files",
# "read_xyz_animation",
# "PeriodicTable",
# "no_repeat",
# "AimsVibrations",
# "VaspVibrations",
# "VisualiseAims",
# ]
2 changes: 1 addition & 1 deletion dfttoolkit/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, supported_files, **kwargs):

# Check if the file path exists
if not os.path.isfile(kwargs[kwarg]):
raise ValueError(f"{kwargs[kwarg]} does not exist.")
raise FileNotFoundError(f"{kwargs[kwarg]} does not exist.")

# Store the file paths
self.file_paths[kwarg] = kwargs[kwarg]
Expand Down
18 changes: 9 additions & 9 deletions dfttoolkit/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,33 +1323,33 @@ def read_elsi_as_csc(

# Get the column pointer
end = 128 + self.n_basis * 8
col_ptr = np.frombuffer(self.lines[128:end], dtype=np.int64)
col_ptr = np.append(col_ptr, self.n_non_zero + 1)
col_ptr -= 1
col_i = np.frombuffer(self.lines[128:end], dtype=np.int64)
col_i = np.append(col_i, self.n_non_zero + 1)
col_i -= 1

# Get the row index
start = end + self.n_non_zero * 4
row_idx = np.array(np.frombuffer(self.lines[end:start], dtype=np.int32))
row_idx -= 1
row_i = np.array(np.frombuffer(self.lines[end:start], dtype=np.int32))
row_i -= 1

if header[2] == 0: # real
nnz_val = np.frombuffer(
nnz = np.frombuffer(
self.lines[start : start + self.n_non_zero * 8],
dtype=np.float64,
)

else: # complex
nnz_val = np.frombuffer(
nnz = np.frombuffer(
self.lines[start : start + self.n_non_zero * 16],
dtype=np.complex128,
)

if csc_format:
return sp.csc_matrix(
(nnz_val, row_idx, col_ptr), shape=(self.n_basis, self.n_basis)
(nnz, row_i, col_i), shape=(self.n_basis, self.n_basis)
)

else:
return sp.csc_matrix(
(nnz_val, row_idx, col_ptr), shape=(self.n_basis, self.n_basis)
(nnz, row_i, col_i), shape=(self.n_basis, self.n_basis)
).toarray()
27 changes: 23 additions & 4 deletions dfttoolkit/parameters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Literal, Union

import dfttoolkit.utils.file_utils as fu
from dfttoolkit.base_parser import BaseParser
Expand Down Expand Up @@ -75,23 +75,42 @@ def add_keywords(self, **kwargs: dict) -> None:
# TODO finish this
raise NotImplementedError

def remove_keywords(self, *args: str) -> None:
def remove_keywords(
self, *args: str, output: Literal["overwrite", "print", "return"] = "return"
) -> Union[None, List[str]]:
"""
Remove keywords from the control.in file.
Parameters
----------
*args : str
Keywords to be removed from the control.in file.
output : Literal["overwrite", "print", "return"], default="overwrite"
Overwrite the original file, print the modified file to STDOUT, or return
the modified file as a list of '\\n' separated strings.
Returns
-------
Union[None, List[str]]
If output is "return", the modified file is returned as a list of '\\n'
separated strings.
"""

for keyword in args:
for i, line in enumerate(self.lines):
if keyword in line:
self.lines.pop(i)

with open(self.path, "w") as f:
f.writelines(self.lines)
match output:
case "overwrite":
with open(self.path, "w") as f:
f.writelines(self.lines)

case "print":
print(*self.lines, sep="")

case "return":
return self.lines

def get_keywords(self) -> dict:
"""
Expand Down
Empty file removed dfttoolkit/utils/__init__.py
Empty file.
9 changes: 5 additions & 4 deletions dfttoolkit/utils/math_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union
from copy import deepcopy
from typing import Union

import numpy as np
import numpy.typing as npt
import scipy
Expand Down Expand Up @@ -405,7 +406,7 @@ def apply_hann_window(data):
return windowed_data


def norm_matrix_by_dagonal(matrix: np.array) -> np.array:
def norm_matrix_by_dagonal(matrix: npt.NDArray) -> npt.NDArray:
"""
Norms a matrix such that the diagonal becomes 1.
Expand Down Expand Up @@ -437,7 +438,7 @@ def norm_matrix_by_dagonal(matrix: np.array) -> np.array:
return new_matrix


def mae(delta: np.ndarray) -> float:
def mae(delta: np.ndarray) -> np.floating:
"""
Calculated the mean absolute error from a list of value differnces.
Expand All @@ -455,7 +456,7 @@ def mae(delta: np.ndarray) -> float:
return np.mean(np.abs(delta))


def rel_mae(delta: np.ndarray, target_val: np.ndarray) -> float:
def rel_mae(delta: np.ndarray, target_val: np.ndarray) -> np.floating:
"""
Calculated the relative mean absolute error from a list of value differnces,
given the target values.
Expand Down
14 changes: 10 additions & 4 deletions dfttoolkit/utils/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,20 @@ def no_repeat(
def _no_repeat(func):
@wraps(func)
def wrapper(*args, **kwargs):
if not os.path.isdir(calc_dir):
raise ValueError(f"{calc_dir} is not a directory.")
# Override calc_dir in decorator call if given in func
if "calc_dir" in kwargs:
check_dir = kwargs["calc_dir"]
else:
check_dir = calc_dir

if not os.path.isdir(check_dir):
raise ValueError(f"{check_dir} is not a directory.")
if force:
return func(*args, **kwargs)
if not os.path.isfile(f"{calc_dir}/{output_file}"):
if not os.path.isfile(f"{check_dir}/{output_file}"):
return func(*args, **kwargs)
else:
print(f"aims.out already exists in {calc_dir}. Skipping calculation.")
print(f"aims.out already exists in {check_dir}. Skipping calculation.")

return wrapper

Expand Down
Loading

0 comments on commit 692b0b1

Please sign in to comment.