Skip to content

Commit

Permalink
Merge pull request #344 from tovrstra/prepare-dump-light
Browse files Browse the repository at this point in the history
Add light version of prepare_dump API
  • Loading branch information
tovrstra authored Jun 20, 2024
2 parents 07421c3 + e09e8a3 commit c74af1a
Show file tree
Hide file tree
Showing 16 changed files with 265 additions and 91 deletions.
23 changes: 15 additions & 8 deletions iodata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import Callable, Optional

from .iodata import IOData
from .utils import FileFormatError, LineIterator
from .utils import FileFormatError, LineIterator, PrepareDumpError

__all__ = ["load_one", "load_many", "dump_one", "dump_many", "write_input"]

Expand Down Expand Up @@ -185,12 +185,12 @@ def _check_required(iodata: IOData, dump_func: Callable):
Raises
------
FileFormatError
PrepareDumpError
When a required attribute is ``None``.
"""
for attr_name in dump_func.required:
if getattr(iodata, attr_name) is None:
raise FileFormatError(
raise PrepareDumpError(
f"Required attribute {attr_name}, for format {dump_func.fmt}, is None."
)

Expand All @@ -216,11 +216,15 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs)
Raises
------
FileFormatError
When one of the iodata items does not have the required attributes.
PrepareDumpError
When the iodata object is not compatible with the file format,
e.g. due to missing attributes, and not conversion is available or allowed
to make it compatible.
"""
format_module = _select_format_module(filename, "dump_one", fmt)
_check_required(iodata, format_module.dump_one)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(iodata)
with open(filename, "w") as f:
format_module.dump_one(f, iodata, **kwargs)

Expand All @@ -245,9 +249,10 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non
Raises
------
FileFormatError
When iodatas has zero length
or when one of the iodata items does not have the required attributes.
PrepareDumpError
When the iodata object is not compatible with the file format,
e.g. due to missing attributes, and not conversion is available or allowed
to make it compatible.
"""
format_module = _select_format_module(filename, "dump_many", fmt)

Expand All @@ -267,6 +272,8 @@ def checking_iterator():
yield first
for other in iter_iodatas:
_check_required(other, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(other)
yield other

with open(filename, "w") as f:
Expand Down
2 changes: 1 addition & 1 deletion iodata/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class MolecularBasis:
shells: list[Shell] = attrs.field()
"""A list of objects of type Shell which can support generalized contractions."""

conventions: dict[str, str] = attrs.field()
conventions: dict[tuple[int, str], list[str]] = attrs.field()
"""
A dictionary specifying the ordered basis functions for a given angular momentum and kind.
The key is a tuple of angular momentum integer and kind character ('c' for Cartesian
Expand Down
35 changes: 26 additions & 9 deletions iodata/formats/fchk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ..docstrings import document_dump_one, document_load_many, document_load_one
from ..iodata import IOData
from ..orbitals import MolecularOrbitals
from ..utils import LineIterator, amu
from ..utils import LineIterator, PrepareDumpError, amu

__all__ = []

Expand Down Expand Up @@ -542,6 +542,31 @@ def _dump_real_arrays(name: str, val: NDArray[float], f: TextIO):
k = 0


def prepare_dump(data: IOData):
"""Check the compatibility of the IOData object with the FCHK format.
Parameters
----------
data
The IOData instance to be checked.
"""
if data.mo is not None:
if data.mo.kind == "generalized":
raise PrepareDumpError("Cannot write FCHK file with generalized orbitals.")
na = int(np.round(np.sum(data.mo.occsa)))
if not ((data.mo.occsa[:na] == 1.0).all() and (data.mo.occsa[na:] == 0.0).all()):
raise PrepareDumpError(
"Cannot dump FCHK because it does not have fully occupied alpha orbitals "
"followed by fully virtual ones."
)
nb = int(np.round(np.sum(data.mo.occsb)))
if not ((data.mo.occsb[:nb] == 1.0).all() and (data.mo.occsb[nb:] == 0.0).all()):
raise PrepareDumpError(
"Cannot dump FCHK because it does not have fully occupied beta orbitals "
"followed by fully virtual ones."
)


@document_dump_one(
"Gaussian Formatted Checkpoint",
["atnums", "atcorenums"],
Expand Down Expand Up @@ -579,16 +604,8 @@ def dump_one(f: TextIO, data: IOData):
if data.charge is not None:
_dump_integer_scalars("Charge", int(data.charge), f)
if data.mo is not None:
# check occupied orbitals are followed by virtuals
if data.mo.kind == "generalized":
raise ValueError("Cannot dump FCHK because given MO kind is generalized!")
# check integer occupations b/c FCHK assumes these have a specific order.
na = int(np.round(np.sum(data.mo.occsa)))
if not ((data.mo.occsa[:na] == 1.0).all() and (data.mo.occsa[na:] == 0.0).all()):
raise ValueError("Cannot dump FCHK because of fractional alpha occupation numbers.")
nb = int(np.round(np.sum(data.mo.occsb)))
if not ((data.mo.occsb[:nb] == 1.0).all() and (data.mo.occsb[nb:] == 0.0).all()):
raise ValueError("Cannot dump FCHK because of fractional beta occupation numbers.")
# assign number of alpha and beta electrons
multiplicity = abs(na - nb) + 1
_dump_integer_scalars("Multiplicity", multiplicity, f)
Expand Down
20 changes: 17 additions & 3 deletions iodata/formats/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@
from ..docstrings import document_dump_one, document_load_one
from ..iodata import IOData
from ..periodic import num2sym, sym2num
from ..utils import FileFormatError, FileFormatWarning, LineIterator
from ..utils import FileFormatError, FileFormatWarning, LineIterator, PrepareDumpError

__all__ = []

Expand Down Expand Up @@ -1436,15 +1436,29 @@ def _parse_provenance(
return base_provenance


def prepare_dump(data: IOData):
"""Check the compatibility of the IOData object with QCScheme.
Parameters
----------
data
The IOData instance to be checked.
"""
if "schema_name" not in data.extra:
raise PrepareDumpError("Cannot write qcschema file without 'schema_name' defined.")
schema_name = data.extra["schema_name"]
if schema_name == "qcschema_basis":
raise PrepareDumpError(f"{schema_name} not yet implemented in IOData.")


@document_dump_one(
"QCSchema",
["atnums", "atcoords", "charge", "spinpol"],
["title", "atcorenums", "atmasses", "bonds", "g_rot", "extra"],
)
def dump_one(f: TextIO, data: IOData):
"""Do not edit this docstring. It will be overwritten."""
if "schema_name" not in data.extra:
raise FileFormatError("Cannot write qcschema file without 'schema_name' defined.")
schema_name = data.extra["schema_name"]

if schema_name == "qcschema_molecule":
Expand Down
28 changes: 20 additions & 8 deletions iodata/formats/molden.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from ..orbitals import MolecularOrbitals
from ..overlap import compute_overlap, gob_cart_normalization
from ..periodic import num2sym, sym2num
from ..utils import LineIterator, angstrom
from ..utils import LineIterator, PrepareDumpError, angstrom

__all__ = []

Expand Down Expand Up @@ -742,13 +742,27 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold
)


def prepare_dump(data: IOData):
"""Check the compatibility of the IOData object with the Molden format.
Parameters
----------
data
The IOData instance to be checked.
"""
if data.mo is None:
raise PrepareDumpError("The Molden format requires molecular orbitals.")
if data.obasis is None:
raise PrepareDumpError("The Molden format requires an orbital basis set.")
if data.mo.occs_aminusb is not None:
raise PrepareDumpError("Cannot write Molden file when mo.occs_aminusb is set.")
if data.mo.kind == "generalized":
raise PrepareDumpError("Cannot write Molden file with generalized orbitals.")


@document_dump_one("Molden", ["atcoords", "atnums", "mo", "obasis"], ["atcorenums", "title"])
def dump_one(f: TextIO, data: IOData):
"""Do not edit this docstring. It will be overwritten."""
# occs_aminusb is not supported
if data.mo.occs_aminusb is not None:
raise ValueError("Cannot write Molden file when mo.occs_aminusb is set.")

# Print the header
f.write("[Molden Format]\n")
if data.title is not None:
Expand All @@ -768,8 +782,6 @@ def dump_one(f: TextIO, data: IOData):
f.write("\n")

# Print the basis set
if data.obasis is None:
raise OSError("A Gaussian orbital basis is required to write a molden file.")
obasis = data.obasis

# Figure out the pure/Cartesian situation. Note that the Molden
Expand Down Expand Up @@ -864,7 +876,7 @@ def dump_one(f: TextIO, data: IOData):
irreps,
)
else:
raise NotImplementedError
raise RuntimeError("This should not happen because of prepare_dump")


def _dump_helper_orb(f, spin, occs, coeffs, energies, irreps):
Expand Down
30 changes: 22 additions & 8 deletions iodata/formats/molekel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ..docstrings import document_dump_one, document_load_one
from ..iodata import IOData
from ..orbitals import MolecularOrbitals
from ..utils import LineIterator, angstrom
from ..utils import LineIterator, PrepareDumpError, angstrom
from .molden import CONVENTIONS, _fix_molden_from_buggy_codes

__all__ = []
Expand Down Expand Up @@ -258,13 +258,27 @@ def load_one(lit: LineIterator, norm_threshold: float = 1e-4) -> dict:
return result


def prepare_dump(data: IOData):
"""Check the compatibility of the IOData object with the Molekel format.
Parameters
----------
data
The IOData instance to be checked.
"""
if data.mo is None:
raise PrepareDumpError("The Molekel format requires molecular orbitals.")
if data.obasis is None:
raise PrepareDumpError("The Molekel format requires an orbital basis set.")
if data.mo.occs_aminusb is not None:
raise PrepareDumpError("Cannot write Molekel file when mo.occs_aminusb is set.")
if data.mo.kind == "generalized":
raise PrepareDumpError("Cannot write Molekel file with generalized orbitals.")


@document_dump_one("Molekel", ["atcoords", "atnums", "mo", "obasis"], ["atcharges"])
def dump_one(f: TextIO, data: IOData):
"""Do not edit this docstring. It will be overwritten."""
# occs_aminusb is not supported
if data.mo.occs_aminusb is not None:
raise ValueError("Cannot write Molekel file when mo.occs_aminusb is set.")

# Header
f.write("$MKL\n")
f.write("#\n")
Expand Down Expand Up @@ -339,7 +353,7 @@ def dump_one(f: TextIO, data: IOData):
_dump_helper_occ(f, data, spin="b")

else:
raise ValueError(f"The MKL format does not support {data.mo.kind} orbitals.")
raise RuntimeError("This should not happen because of prepare_dump")


# Defining help dumping functions
Expand All @@ -356,7 +370,7 @@ def _dump_helper_coeffs(f, data, spin=None):
ener = data.mo.energiesb
irreps = data.mo.irreps[norb:] if data.mo.irreps is not None else ["a1g"] * norb
else:
raise OSError("A spin must be specified")
raise ValueError("A spin must be specified")

for j in range(0, norb, 5):
en = " ".join([f" {e: ,.12f}" for e in ener[j : j + 5]])
Expand All @@ -382,7 +396,7 @@ def _dump_helper_occ(f, data, spin=None):
norb = data.mo.norba
occ = data.mo.occs
else:
raise OSError("A spin must be specified")
raise ValueError("A spin must be specified")

for j in range(0, norb, 5):
occs = " ".join([f" {o: ,.7f}" for o in occ[j : j + 5]])
Expand Down
28 changes: 22 additions & 6 deletions iodata/formats/wfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from ..orbitals import MolecularOrbitals
from ..overlap import gob_cart_normalization
from ..periodic import num2sym, sym2num
from ..utils import LineIterator
from ..utils import LineIterator, PrepareDumpError

__all__ = []

Expand Down Expand Up @@ -494,23 +494,39 @@ def _dump_helper_section(f: TextIO, data: NDArray, fmt: str, skip: int, step: in
DEFAULT_WFN_TTL = "WFN auto-generated by IOData"


def prepare_dump(data: IOData):
"""Check the compatibility of the IOData object with the WFN format.
Parameters
----------
data
The IOData instance to be checked.
"""
if data.mo is None:
raise PrepareDumpError("The WFN format requires molecular orbitals")
if data.obasis is None:
raise PrepareDumpError("The WFN format requires an orbital basis set")
if data.mo.kind == "generalized":
raise PrepareDumpError("Cannot write WFN file with generalized orbitals.")
if data.mo.occs_aminusb is not None:
raise PrepareDumpError("Cannot write WFN file when mo.occs_aminusb is set.")
for shell in data.obasis.shells:
if any(kind != "c" for kind in shell.kinds):
raise PrepareDumpError("The WFN format only supports Cartesian MolecularBasis.")


@document_dump_one(
"WFN",
["atcoords", "atnums", "mo", "obasis"],
["energy", "title", "extra"],
)
def dump_one(f: TextIO, data: IOData) -> None:
"""Do not edit this docstring. It will be overwritten."""
# occs_aminusb is not supported
if data.mo.occs_aminusb is not None:
raise ValueError("Cannot write WFN file when mo.occs_aminusb is set.")
# get shells for the de-contracted basis
shells = []
for shell in data.obasis.shells:
for i, (angmom, kind) in enumerate(zip(shell.angmoms, shell.kinds)):
for exponent, coeff in zip(shell.exponents, shell.coeffs.T[i]):
if kind != "c":
raise ValueError("WFN can be generated only for Cartesian MolecularBasis!")
shells.append(
Shell(
shell.icenter, [angmom], [kind], np.array([exponent]), coeff.reshape(-1, 1)
Expand Down
Loading

0 comments on commit c74af1a

Please sign in to comment.