From e49c43886caeecaa5996d21825df8eda55d4dc8e Mon Sep 17 00:00:00 2001 From: ndaelman-hu <107392603+ndaelman-hu@users.noreply.github.com> Date: Tue, 27 Aug 2024 22:01:34 +0200 Subject: [PATCH] 101 complete basis set migration (#102) - Migrate / Add the following basis sets: - plane waves (no pseudopotentials) - LAPW -- revised structure centered on l-channels (similar to Gulans' uploads) -- automatic type detection (new) -- scaffhold generator (new) - GPW - AtomCentered (just placeholder) - Add testing for all features (new) - Touch up `Mesh` - Add normalization flow decorators to `general.py` --------- Co-authored-by: ndaelman Co-authored-by: nathan --- docs/model_method/basis_sets.md | 144 ++++ .../schema_packages/basis_set.py | 698 ++++++++++++++++++ .../schema_packages/general.py | 38 +- .../schema_packages/model_method.py | 2 +- .../schema_packages/model_system.py | 16 +- .../schema_packages/numerical_settings.py | 87 +-- tests/conftest.py | 63 ++ tests/test_basis_set.py | 420 +++++++++++ 8 files changed, 1417 insertions(+), 51 deletions(-) create mode 100644 docs/model_method/basis_sets.md create mode 100644 src/nomad_simulations/schema_packages/basis_set.py create mode 100644 tests/test_basis_set.py diff --git a/docs/model_method/basis_sets.md b/docs/model_method/basis_sets.md new file mode 100644 index 00000000..4050145d --- /dev/null +++ b/docs/model_method/basis_sets.md @@ -0,0 +1,144 @@ +# Basis Sets + +The following lays down the schema annotation for several families of basis sets. +We start off genercially before running over specific examples. +The aim is not to introduce the full theory behind every basis set, but just enough to understand its main concepts and how they relate. + +## General Structure + +Basis sets are used by codes to represent various kinds of electronic structures, e.g. wavefunctions, densities, exchange densities, etc. +Each electronic structure is therefore described by an individual `BasisSetContainer` in the schema. + +Basis sets may be partitioned by various regions, spanning either physical / reciprocal space, energy (i.e. core vs. valence), or potential / Hamiltonian. +We will cover the partitions per example below. +Each `BasisSetContainer` is then constructed out of several `basis_set_components` matching a single region. +Sometimes, a region is defined in terms of another schema section, e.g. an atom center (`species_scope`) or Hamiltonian terms (`hamiltonian_scope`). + +Note that typically, different kinds of regions also have different mathematical formulations. +Each formulation has its own dedicated section, to facilitate their reuse. +These are all derived from the abstract section `BasisSetComponent`, so that `basis_set_components: list[BasisSetComponent]`. + +Generically, `BasisSetComponent` will allude to the the formula at large and just focus on capturing the _subtype_, as well as relevant _parameters_. +The most relevant ones are those that most commonly listed in the Method section of an article. +These typically also influence the _precision_ most. +Extra, code-specific subtypes and parameters can be added by their respective parsers. + +This then coalesces into the following diagram: + +``` +ModelMethod +└── NumericalSettings[0] +└── ... +└── NumericalSettings[n] = BasisSetContainer + └── BasisSetComponent[1] + └── ... + └── BasisSetComponent[n] + └──> AtomsState + └──> BaseModelMethod +``` + +## Plane-waves + +Plane-wave basis sets start from the description of a free electron and use Fourier to construct the representations of bound electrons. +In reciprocal space, the basis set can thus be thought of as vectors in a Cartesian* grid enclosed within a sphere. + +The main parameter is the spherical radius, i.e. the _cutoff_, which corresponds to the highest frequency representable Fourier frequency. +By convention, the radius is typically expressed in terms of the kinetic energy for the matching free-electron wave. +`PlaneWaveBasisSet` allows storing either `cutoff_radius` and `cutoff_energy`. +It can even derive the former from the latter via normalization. + +### Pseudopotentials + +Under construction... + +## LAPW + +The family of linearized augmented plane-waves is one of the best examples of region partitioning: + +- first it partitions the physical space into regions surrounding the atomic nuclei, i.e. the _muffin-tin spheres_, and the rest, i.e. the _interstitial region_. +- it then further partitions the muffin tins by energy, i.e. core versus valence. +Note that unlike with pseudpotentials, the electrons are not abstracted away here. +They are instead explicitly accounted for and relaxed, just via a different representation. +Hence, LAPW is a _full-electron approach_. + +The interstitial region, covering mostly loose bonding, is described by plane-waves (`APWPlaneWaveBasisSet`). [1] +The valence electrons in the muffin tin (`MuffinTinRegion`), meanwhile, are represented by the spherically symmetric Schrödigner equation. [1] +They follow the additional constraint of having to match the plane-wave description. +In that sense, where the plane-wave description becomes too expensive, it is "augmented" by the muffin-tin description. +This results in a lower plane-wave cutoff. + +The spherically symmetric Schrödigner equation decomposes into an angular and radial part. +In traditional APW (not supported in NOMAD), the angular and radial part are coupled in a non-linear fashion via the radial energy (at the boundary). +All versions of LAPW simplify the coupling by parametrizing this radial energy. [1] + +The representation vector is then developed in terms of the angular basis vectors, i.e. $l$-channels, each with their corresponding radial energy parameter. +This approach is -confusingly- also called _APW_. +It is typically not found standalone, though. +Instead, the linearization introduces a secondary representation via the first-order derivative of the basis vector (function). +Both vectors are typically developed together. +This technique is called linearized APW (LAPW). [1] + +Other formulas have been experimented with too. +For example, the use of even higher-order derivatives, i.e. superlinearized APW (SLAPW). [2, 3] +All of these types are captured by `APWOrbital`, where `type` distinguishes between APW, LAPW, or SLAPW. +The `name` quantity + +Another option is to stay with APW (or LAPW) and add standalone vectors targeting specific atomic states, e.g. high-energy core states, valence states, etc. +These are called _local orbitals_ (lo) and bear other constraints. +Some authors distinguish different vector sums with different kinds of local orbitals, e.g. lo, LO, high-dimensional LO (HDLO). [2, 4] +Since there is no community-wide consensus on the use of these abbreviations, we only utilize `lo` via `APWLocalOrbital`. + +In summary, a generic LAPW basis set can thus be summarized as follows: + +``` +LAPW+lo +├── 1 x plane-wave basis set +└── n x muffin-tin regions + └── l_max x l-channels + ├── orbitals + └── local orbitals ? +``` + +or in terms of the schema: + +``` +BasisSetContainer(name: LAPW+lo) +├── APWPlaneWaveBasisSet +├── MuffinTinRegion(atoms_state: atom A) +├── ... +└── MuffinTinRegion(atoms_state: atom N) + ├── channel 0 + ├── ... + └── channel l_max + ├── APWOrbital(type: lapw) + └── APWLocalOrbital ? +``` + +[1]: D. J. Singh and L. Nordström, \"INTRODUCTION TO THE LAPW METHOD,\" in Planewaves, pseudopotentials, and the LAPW method, 2nd ed. New York, NY: Springer, 2006. + +[2]: A. Gulans, S. Kontur, et al., exciting: a full-potential all-electron package implementing density-functional theory and many-body perturbation theory, _J. Phys.: Condens. Matter_ **26** (363202), 2014. DOI: 10.1088/0953-8984/26/36/363202 + +[3]: J. VandeVondele, M. Krack, et al., WIEN2k: An APW+lo program for calculating the properties of solids, _J. Chem. Phys._ **152**(074101), 2020. DOI: 10.1063/1.5143061 + +[4]: D. Singh and H. Krakauer, H-point phonon in molybdenum: Superlinearized augmented-plane-wave calculations, _Phys. Rev. B_ **43**(1441), 1991. DOI: 10.1103/PhysRevB.43.1441 + +## Gaussian-Planewaves (GPW) + +The CP2K code introduces an algorithm called QuickStep that partitions by Hamiltonian, describing + +- the kinetic and Coulombic electron-nuclei interaction terms of a Gaussian-type orbital (GTO). +- the electronic Hartree energy via plane-waves. + +This GPW choice is to increase performance. [1] +In the schema, we would write: + +``` +BasisSetContainer(name: GPW) +├── PlaneWaveBasisSet(hamiltonian_scope: [`/path/to/kinetic_term/hamiltonian`, `/path/to/e-n_term/hamiltonian`]) +└── AtomCenteredBasisSet(name: GTO, hamiltonian_scope: [`/path/to/hartree_term/hamiltonian`]) +``` + +For further details on the schema, see the CP2K parser documentation. + +[1]: J. VandeVondele, M. Krack, et al., Quickstep: Fast and accurate density functional calculations using a mixed Gaussian and plane waves approach, +_Comp. Phys. Commun._ **167**(2), 103-128, 2005. DOI: 10.1016/j.cpc.2004.12.014. diff --git a/src/nomad_simulations/schema_packages/basis_set.py b/src/nomad_simulations/schema_packages/basis_set.py new file mode 100644 index 00000000..bbd763e0 --- /dev/null +++ b/src/nomad_simulations/schema_packages/basis_set.py @@ -0,0 +1,698 @@ +import itertools +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Optional + +from scipy import constants as const + +if TYPE_CHECKING: + from nomad.datamodel.datamodel import EntryArchive + from structlog.stdlib import BoundLogger + +import numpy as np +import pint +from nomad import utils +from nomad.datamodel.data import ArchiveSection +from nomad.datamodel.metainfo.annotations import ELNAnnotation +from nomad.metainfo import MEnum, Quantity, SubSection +from nomad.units import ureg + +from nomad_simulations.schema_packages.atoms_state import AtomsState +from nomad_simulations.schema_packages.general import ( + check_normalized, + set_not_normalized, +) +from nomad_simulations.schema_packages.model_method import BaseModelMethod +from nomad_simulations.schema_packages.numerical_settings import ( + KMesh, + Mesh, + NumericalSettings, +) + +logger = utils.get_logger(__name__) + + +class BasisSetComponent(ArchiveSection): + """A type section denoting a basis set component of a simulation. + Should be used as a base section for more specialized sections. + Allows for denoting the basis set's _scope_, i.e. to which entity it applies, + e.g. atoms species, orbital type, Hamiltonian term. + + Examples include: + - mesh-based basis sets, e.g. (projector-)(augmented) plane-wave basis sets + - atom-centered basis sets, e.g. Gaussian-type basis sets, Slater-type orbitals, muffin-tin orbitals + """ + + # TODO check implementation of `BasisSetComponent` for Wannier and Slater-Koster orbitals + + name = Quantity( + type=str, + description=""" + Name of the basis set component. + """, + ) + + species_scope = Quantity( + type=AtomsState, + shape=['*'], + description=""" + Reference to the section `AtomsState` specifying the localization of the basis set. + """, + a_eln=ELNAnnotation(components='ReferenceEditQuantity'), + ) + + # TODO: add atom index-based instantiator for species if not present + + hamiltonian_scope = Quantity( + type=BaseModelMethod, + shape=['*'], + description=""" + Reference to the section `BaseModelMethod` containing the information + of the Hamiltonian term to which the basis set applies. + """, + a_eln=ELNAnnotation(components='ReferenceEditQuantity'), + ) + + # ? band_scope or orbital_scope: valence vs core + + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) + self.name = self.m_def.name + + +class PlaneWaveBasisSet(BasisSetComponent, KMesh): + """ + Basis set over a reciprocal mesh, where each point $k_n$ represents a planar-wave basis function $\frac{1}{\\sqrt{\\omega}} e^{i k_n r}$. + Typically the grid itself is cartesian with only points within a designated sphere considered. + The cutoff radius may be defined by a reciprocal length, or more commonly, the equivalent kinetic energy for a free particle. + + * D. J. Singh and L. Nordström, \"Why Planewaves\" in Planewaves, pseudopotentials, and the LAPW method, 2nd ed. New York, NY: Springer, 2006, pp. 24-26. + """ + + cutoff_energy = Quantity( + type=np.float64, + unit='joule', + description=""" + Cutoff energy for the plane-wave basis set. + The simulation uses plane waves with energies below this cutoff. + """, + ) + + cutoff_radius = Quantity( + type=np.float64, + unit='1/meter', + description=""" + Cutoff radius for the plane-wave basis set. + Is the less frequently used dual to `cutoff_energy`. + """, + ) + + def compute_cutoff_radius( + self, cutoff_energy: Optional[pint.Quantity] + ) -> Optional[pint.Quantity]: + """ + Compute the cutoff radius for the plane-wave basis set, expressed in reciprocal coordinates. + """ + if cutoff_energy is None: + return None + m_e = const.m_e * ureg(const.unit('electron mass')) + h = const.h * ureg(const.unit('Planck constant')) + return np.sqrt(2 * m_e * cutoff_energy) / h + + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) + + self.label = 'g-mesh' + + if self.cutoff_radius is None: + cutoff_radius = self.compute_cutoff_radius(self.cutoff_energy) + if cutoff_radius is None: + logger.warning( + 'Could not calculate `PlaneWaveBasisSet.cutoff_radius`: missing `cutoff_energy`.' + ) + else: + self.cutoff_radius = cutoff_radius + + +class APWPlaneWaveBasisSet(PlaneWaveBasisSet): + """ + A `PlaneWaveBasisSet` specialized to the APW use case. + Its main descriptors are defined in terms of the `MuffinTin` regions. + """ + + cutoff_fractional = Quantity( + type=np.float64, + shape=[], + description=""" + The spherical cutoff parameter for the interstitial plane waves in the APW family. + This cutoff has no units, referring to the product of the smallest muffin-tin radius + and the length of the cutoff reciprocal vector ($r_{MT} * |K_{cut}|$). + """, + ) + + def compute_cutoff_fractional( + self, cutoff_radius: Optional[pint.Quantity], mt_r_min: Optional[pint.Quantity] + ) -> Optional[pint.Quantity]: + """ + Compute the fractional cutoff parameter for the interstitial plane waves in the LAPW family. + + Args: + - cutoff_radius (Optional[pint.Quantity]): The cutoff radius. + - mt_r_min (Optional[pint.Quantity]): The smallest muffin-tin radius within the `BasisSetContainer`. + """ + reference_unit = 'angstrom' + if cutoff_radius is None or mt_r_min is None: + return None + return cutoff_radius.to(f'1 / {reference_unit}') * mt_r_min.to(reference_unit) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mt_r_min = None + + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) # 1st compute `cutoff_radius`` + if self.cutoff_fractional is None: + logger.warning( + 'Expected `APWPlaneWaveBasisSet.cutoff_fractional` to be defined. Will attempt to calculate.' + ) + cutoff_fractional = self.compute_cutoff_fractional( + self.cutoff_radius, self.mt_r_min + ) + if cutoff_fractional is None: + logger.warning( + 'Could not calculate `APWPlaneWaveBasisSet.cutoff_fractional`: missing `cutoff_radius` or `mt_r_min`.' + ) + else: + self.cutoff_fractional = cutoff_fractional + + +class AtomCenteredFunction(ArchiveSection): + """ + Specifies a single function (term) in an atom-centered basis set. + """ + + pass + + # TODO: design system for writing basis functions like gaussian or slater orbitals + + +class AtomCenteredBasisSet(BasisSetComponent): + """ + Defines an atom-centered basis set. + """ + + functional_composition = SubSection( + sub_section=AtomCenteredFunction.m_def, repeats=True + ) # TODO change name + + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) + # self.name = self.m_def.name + # TODO: set name based on basis functions + # ? use basis set names from Basis Set Exchange + + +class APWBaseOrbital(ArchiveSection): + """ + Abstract base section for (S)(L)APW and local orbital component wavefunctions. + It helps defining the interface with `APWLChannel`. + """ + + n_terms = Quantity( + type=np.int32, + description=""" + Number of terms in the local orbital. + """, + ) + + energy_parameter = Quantity( + type=np.float64, + shape=['n_terms'], + unit='joule', + description=""" + Reference energy parameter for the augmented plane wave (APW) basis set. + Is used to set the energy parameter for each state. + """, + ) # TODO: add approximation formula from energy parameter n + + energy_parameter_n = Quantity( + type=np.int32, + shape=['n_terms'], + description=""" + Reference number of radial nodes for the augmented plane wave (APW) basis set. + This is used to derive the `energy_parameter`. + """, + ) + + energy_status = Quantity( + type=MEnum('fixed', 'pre-optimization', 'post-optimization'), + default='post-optimization', + description=""" + Allow the code to optimize the initial energy parameter. + """, + ) + + differential_order = Quantity( + type=np.int32, + shape=['n_terms'], + description=""" + Derivative order of the radial wavefunction term. + """, + ) # TODO: add check non-negative # ? to remove + + def _get_open_quantities(self) -> set[str]: + """Extract the open quantities of the `APWBaseOrbital`.""" + return { + k for k, v in self.m_def.all_quantities.items() if self.m_get(v) is not None + } + + def _get_lengths(self, quantities: set[str]) -> list[int]: + """Extract the lengths of the `quantities` contained in the set.""" + present_quantities = set(quantities) & self._get_open_quantities() + return [len(getattr(self, quant)) for quant in present_quantities] + + def _of_equal_length(self, lengths: list[int]) -> bool: + """Check if all elements in the list are of equal length.""" + if len(lengths) == 0: + return True + else: + ref_length = lengths[0] + return all(length == ref_length for length in lengths) + + def get_n_terms( + self, + representative_quantities: set[str] = { + 'energy_parameter', + 'energy_parameter_n', + 'differential_order', + }, + ) -> Optional[int]: + """Determine the value of `n_terms` based on the lengths of the representative quantities.""" + lengths = self._get_lengths(representative_quantities) + if not self._of_equal_length(lengths) or len(lengths) == 0: + return None + else: + return lengths[0] + + def _check_non_negative(self, quantity_names: set[str]) -> bool: + """Check if all elements in the set are non-negative.""" + for quantity_name in quantity_names: + if isinstance(quant := self.get(quantity_name), Iterable): + if np.any(np.array(quant) <= 0): + return False + return True + + @check_normalized + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) + + # enforce quantity length (will be used for type assignment) + new_n_terms = self.get_n_terms() + if self.n_terms is None: + self.n_terms = new_n_terms + elif self.n_terms != new_n_terms: + logger.error( + f'Inconsistent lengths of `APWBaseOrbital` quantities: {self.m_def.quantities}. Setting back to `None`.' + ) + self.n_terms = None + + # enforce differential order constraints + for quantity_name in ('differential_order', 'energy_parameter_n'): + if self._check_non_negative({quantity_name}): + self.m_set(self.m_def.all_quantities[quantity_name], None) + logger.error( + f'`{self.m_def}.{quantity_name}` must be completely non-negative. Resetting to `None`.' + ) + + # use the differential order as naming convention + self.name = ( + 'APW-like' + if self.differential_order is None or len(self.differential_order) == 0 + else f'{sorted(self.differential_order)}' + ) + + @set_not_normalized + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # it's hard to enforce commutative diagrams between `_determine_apw` and `normalize` + # instead, make all `_determine_apw` soft-coupled and dependent on the normalized state + # leverage normalize ∘ normalize = normalize + + +class APWOrbital(APWBaseOrbital): + """ + Implementation of `APWWavefunction` capturing the foundational (S)(L)APW basis sets, all of the form $\\sum_{lm} \\left[ \\sum_o c_{lmo} \frac{\\partial}{\\partial r}u_l(r, \\epsilon_l) \right] Y_lm$. + The energy parameter $\\epsilon_l$ is always considered fixed during diagonalization, opposed to the original APW formulation. + This representation then has to match the plane-wave $k_n$ points within the muffin-tin sphere. + + Its `name` is showcased as `(s)(l)apw: `. + + * D. J. Singh and L. Nordström, \"INTRODUCTION TO THE LAPW METHOD,\" in Planewaves, pseudopotentials, and the LAPW method, 2nd ed. New York, NY: Springer, 2006, pp. 43-52. + """ + + type = Quantity( + type=MEnum('apw', 'lapw', 'slapw'), # ? add 'spherical_dirac' + description=r""" + Type of augmentation contribution. Abbreviations stand for: + | name | description | radial product | + |------|-------------|----------------| + | APW | augmented plane wave with parametrized energy levels | $A_{lm, k_n} u_l (r, E_l)$ | + | LAPW | linearized augmented plane wave with an optimized energy parameter | $A_{lm, k_n} u_l (r, E_l) + B_{lm, k_n} \dot{u}_{lm} (r, E_l^')$ | + | SLAPW | super linearized augmented plane wave | -- | + + * http://susi.theochem.tuwien.ac.at/lapw/ + """, + ) + + def do_to_type(self, do: Optional[list[int]]) -> Optional[str]: + """ + Set the type of the APW orbital based on the differential order. + """ + if do is None or len(do) == 0: + return None + + do = sorted(do) + if do == [0]: + return 'apw' + elif do == [0, 1]: + return 'lapw' + elif max(do) > 1: # exciting definition + return 'slapw' + else: + return None + + @check_normalized + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) + # assign a APW orbital type + # this snippet works of the previous normalization + new_type = self.do_to_type(self.differential_order) + if self.type is None: + self.type = new_type + elif self.type != new_type: + logger.error( + f'Inconsistent `APWOrbital` type: {self.type}. Setting back to `None`.' + ) + self.type = None + + self.name = ( + f'{self.type.upper()}: {self.name}' + if self.type and len(self.differential_order) > 0 + else self.name + ) + + +class APWLocalOrbital(APWBaseOrbital): + """ + Implementation of `APWWavefunction` capturing a local orbital extending a foundational APW basis set. + Local orbitals allow for flexible additions to an `APWOrbital` specification. + They may be included to describe semi-core states, virtual states, ghost bands, or improve overall convergence. + + * D. J. Singh and L. Nordström, \"Role of the Linearization Energies,\" in Planewaves, pseudopotentials, and the LAPW method, 2nd ed. New York, NY: Springer, 2006, pp. 49-52. + """ + + # there's no community consensus on `type` + + @check_normalized + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) + self.name = ( + f'LO: {sorted(self.differential_order)}' + if self.differential_order + else 'LO' + ) + + +class APWLChannel(BasisSetComponent): + """ + Collection of all (S)(L)APW and local orbital components that contribute + to a single $l$-channel. $l$ here stands for the angular momentum parameter + in the Laplace spherical harmonics $Y_{l, m}$. + """ + + name = Quantity( + type=np.int32, + description=""" + Angular momentum quantum number of the local orbital. + """, + ) + + n_orbitals = Quantity( + type=np.int32, + description=""" + Number of wavefunctions in the l-channel, i.e. $(2l + 1) n_orbitals$. + """, + ) + + orbitals = SubSection(sub_section=APWBaseOrbital.m_def, repeats=True) + + def _determine_apw(self) -> dict[str, int]: + """ + Produce a count of the APW components in the l-channel. + Invokes `normalize` on `orbitals` to ensure the existence of `type`. + """ + for orb in self.orbitals: + orb.normalize(None, logger) + + type_count = {'apw': 0, 'lapw': 0, 'slapw': 0, 'lo': 0, 'other': 0} + for orb in self.orbitals: + if orb.type is None: + type_count['other'] += 1 + elif isinstance(orb, APWOrbital) and orb.type.lower() in type_count.keys(): + type_count[orb.type] += 1 + elif isinstance(orb, APWLocalOrbital): + type_count['lo'] += 1 + else: + type_count['other'] += 1 # other de facto operates as a catch-all + return type_count + + @set_not_normalized + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @check_normalized + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + # call order: parent of `BasisSetComponent``, then `self` + super(BasisSetComponent, self).normalize(archive, logger) + self.n_orbitals = len(self.orbitals) + + +class MuffinTinRegion(BasisSetComponent, Mesh): + """ + Muffin-tin region around atoms, containing the augmented part of the APW basis set. + The latter is structured by l-channel. Each channel contains a base (S)(L)APW definition, + which may be extended via local orbitals. + """ + + # there are 2 main ways of structuring the APW basis set + # either as APW and lo in the MT region + # or by l-channel in the MT region + + radius = Quantity( + type=np.float64, + unit='meter', + description=""" + The radius descriptor of the `MuffinTin` is spherical shape. + """, + ) + + l_max = Quantity( + type=np.int32, + description=""" + Maximum angular momentum quantum number that is sampled. + Starts at 0. + """, + ) + + l_channels = SubSection(sub_section=APWLChannel.m_def, repeats=True) + + def _determine_apw(self) -> dict[str, int]: + """ + Aggregate the APW component count in the muffin-tin region. + Invokes `normalize` on `l_channels`. + """ + for l_channel in self.l_channels: + l_channel.normalize(None, logger) + + type_count: dict[str, int] = {} + if len(self.l_channels) > 0: + # dynamically determine `type_count` structure + for l_channel in self.l_channels: + type_count.update(l_channel._determine_apw()) + return type_count + + @set_not_normalized + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.mt_r_min = None + + @check_normalized + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) + # TODO: add spherical specification, once supported in `Grid` + + +class BasisSetContainer(NumericalSettings): + """ + A section defining the full basis set used for representing the electronic structure + during the diagonalization of a Hamiltonian (component), as defined in `ModelMethod`. + This section may contain multiple basis set specifications under the basis_set_components, + each with their own parameters. + """ + + native_tier = Quantity( + type=str, # to be overwritten by a parser `MEnum` + description=""" + Code-specific tag indicating the overall precision based on the basis set parameters. + The naming conventions of the same code are used. See the parser implementation for the possible values. + The number of tiers varies, but a typical example would be `low`, `medium`, `high`. + """, + ) # TODO: rename to `code_specific_tier` + + # TODO: add reference to `electronic_structure`, + # specifying to which electronic structure representation the basis set is applied + # e.g. wavefunction, density, grids for subroutines, etc. + + basis_set_components = SubSection(sub_section=BasisSetComponent.m_def, repeats=True) + + def _determine_apw(self) -> Optional[str]: + """ + Derive the basis set name for a (S)(L)APW case, including local orbitals. + Invokes `normalize` on `basis_set_components`. + """ + has_plane_wave = ( + True + if any( + isinstance(comp, PlaneWaveBasisSet) + for comp in self.basis_set_components + ) + else False + ) + + type_sums: dict[str, int] = {} + for comp in self.basis_set_components: + if isinstance(comp, MuffinTinRegion): + type_count = comp._determine_apw() + for key in type_count.keys(): + type_sums[key] = type_sums.get(key, 0) + type_count[key] + + type_str = 'APW-like' + for key in ('slapw', 'lapw', 'apw'): + try: + if type_sums[key] > 0: + type_str = key.upper() + if type_sums['lo'] > 0: + type_str += '+lo' + break + except KeyError: + pass + + return type_str if has_plane_wave else None + + def _find_mt_r_min(self) -> Optional[pint.Quantity]: + """ + Scan the container for the smallest muffin-tin region. + """ + mt_r_min = None + for comp in self.basis_set_components: + if isinstance(comp, MuffinTinRegion): + if mt_r_min is None or comp.radius < mt_r_min: + mt_r_min = comp.radius + return mt_r_min + + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) + + mt_r_min = self._find_mt_r_min() + plane_waves: list[APWPlaneWaveBasisSet] = [] + for component in self.basis_set_components: + if isinstance(component, PlaneWaveBasisSet): + plane_waves.append(component) + elif isinstance(component, MuffinTinRegion): + component.mt_r_min = mt_r_min + component.normalize(archive, logger) + + if len(plane_waves) == 0: + logger.error('Expected a `APWPlaneWaveBasisSet` instance, but found none.') + elif len(plane_waves) > 1: + logger.warning('Multiple plane-wave basis sets found were found.') + self.name = self._determine_apw() + + +def generate_apw( + species: dict[str, dict[str, Any]], + cutoff: Optional[float] = None, +) -> BasisSetContainer: # TODO: extend to cover all parsing use cases (maybe split up?) + """ + Generate a mock APW basis set with the following structure: + . + ├── 1 x plane-wave basis set + └── n x muffin-tin regions + └── l_max x l-channels + ├── orbitals + └── local orbitals + + from a dictionary + { + : { + 'r': , + 'l_max': , + 'orb_do': [[int]], + 'orb_param': [], + 'lo_do': [[int]], + 'lo_param': [], + } + } + """ + + basis_set_components: list[BasisSetComponent] = [] + if cutoff is not None: + pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff) + basis_set_components.append(pw) + + for sp_ref, sp in species.items(): + sp['r'] = sp.get('r', None) + sp['l_max'] = sp.get('l_max', 0) + sp['orb_d_o'] = sp.get('orb_d_o', []) + sp['orb_param'] = sp.get('orb_param', []) + sp['lo_d_o'] = sp.get('lo_d_o', []) + sp['lo_param'] = sp.get('lo_param', []) + + basis_set_components.extend( + [ + MuffinTinRegion( + species_scope=[sp_ref], + radius=sp['r'], + l_max=sp['l_max'], + l_channels=[ + APWLChannel( + name=l_channel, + orbitals=list( + itertools.chain( + ( + APWOrbital( + energy_parameter=param, # TODO: add energy_parameter_n + differential_order=d_o, + ) + for param, d_o in zip( + sp['orb_param'], sp['orb_d_o'] + ) + ), + ( + APWLocalOrbital( + energy_parameter=param, # TODO: add energy_parameter_n + differential_order=d_o, + ) + for param, d_o in zip( + sp['lo_param'], sp['lo_d_o'] + ) + ), + ) + ), + ) + for l_channel in range(sp['l_max'] + 1) + ], + ) + ] + ) + + return BasisSetContainer(basis_set_components=basis_set_components) diff --git a/src/nomad_simulations/schema_packages/general.py b/src/nomad_simulations/schema_packages/general.py index 2e126221..e3d75fa7 100644 --- a/src/nomad_simulations/schema_packages/general.py +++ b/src/nomad_simulations/schema_packages/general.py @@ -18,6 +18,12 @@ from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Callable + + from nomad.datamodel.datamodel import EntryArchive + from structlog.stdlib import BoundLogger + import numpy as np from nomad.config import config from nomad.datamodel.data import Schema @@ -25,10 +31,6 @@ from nomad.datamodel.metainfo.basesections import Activity, Entity from nomad.metainfo import Datetime, Quantity, SchemaPackage, Section, SubSection -if TYPE_CHECKING: - from nomad.datamodel.datamodel import EntryArchive - from structlog.stdlib import BoundLogger - from nomad_simulations.schema_packages.model_method import ModelMethod from nomad_simulations.schema_packages.model_system import ModelSystem from nomad_simulations.schema_packages.outputs import Outputs @@ -44,6 +46,34 @@ m_package = SchemaPackage() +def set_not_normalized(func: 'Callable'): + """ + Decorator to set the section as not normalized. + Typically decorates the section initializer. + """ + + def wrapper(self, *args, **kwargs) -> None: + func(self, *args, **kwargs) + self._is_normalized = False + + return wrapper + + +def check_normalized(func: 'Callable'): + """ + Decorator to check if the section is already normalized. + Typically decorates the section normalizer. + """ + + def wrapper(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + if self._is_normalized: + return None + func(self, archive, logger) + self._is_normalized = True + + return wrapper + + class Program(Entity): """ A base section used to specify a well-defined program used for computation. diff --git a/src/nomad_simulations/schema_packages/model_method.py b/src/nomad_simulations/schema_packages/model_method.py index 96839a9d..9efc2246 100644 --- a/src/nomad_simulations/schema_packages/model_method.py +++ b/src/nomad_simulations/schema_packages/model_method.py @@ -838,7 +838,7 @@ class ExcitedStateMethodology(ModelMethodElectronic): A base section used to define the parameters typical of excited-state calculations. "ExcitedStateMethodology" mainly refers to methodologies which consider many-body effects as a perturbation of the original DFT Hamiltonian. These are: GW, TDDFT, BSE. - """ + """ # Note: we don't really talk about Hamiltonians in DFT: their physics is accommodated in the functional itself n_states = Quantity( type=np.int32, diff --git a/src/nomad_simulations/schema_packages/model_system.py b/src/nomad_simulations/schema_packages/model_system.py index 70c26649..01ebb30e 100644 --- a/src/nomad_simulations/schema_packages/model_system.py +++ b/src/nomad_simulations/schema_packages/model_system.py @@ -151,12 +151,20 @@ class GeometricSpace(Entity): ) coordinates_system = Quantity( - type=MEnum('cartesian', 'polar', 'cylindrical', 'spherical'), + type=MEnum('cartesian', 'cylindrical', 'spherical', 'ellipsoidal', 'polar'), default='cartesian', description=""" - Coordinate system used to determine the geometrical information of a shape in real - space. Default to 'cartesian'. - """, + Coordinate system used to define geometrical primitives of a shape in real + space. Defaults to 'cartesian'. + + | name | description | dimensionalities | coordinates | + |------------|-------------|------------------|-------------| + | cartesian | coordinate system with fixed angles between the axes (not necessarily 90°) | 1, 2, 3 | x, y, z | + | cylindrical| cylindrical symmetry | 3 | r, theta, z | + | spherical | spherical symmetry | 3 | r, theta, phi | + | ellipsoidal| spherically elongated system | 3 | r, theta, phi | + | polar | spherical symmetry | 2 | r, theta | + """, # ? could this not be extended to the k-space ) origin_shift = Quantity( diff --git a/src/nomad_simulations/schema_packages/numerical_settings.py b/src/nomad_simulations/schema_packages/numerical_settings.py index 71d3eb27..1e69a892 100644 --- a/src/nomad_simulations/schema_packages/numerical_settings.py +++ b/src/nomad_simulations/schema_packages/numerical_settings.py @@ -54,15 +54,29 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) +class Smearing(NumericalSettings): + """ + Section specifying the smearing of the occupation numbers to + either simulate temperature effects or improve SCF convergence. + """ + + name = Quantity( + type=MEnum('Fermi-Dirac', 'Gaussian', 'Methfessel-Paxton'), + description=""" + Smearing routine employed. + """, + ) + + class Mesh(ArchiveSection): """ - A base section used to specify the settings of a sampling mesh. It supports uniformly-spaced - meshes and symmetry-reduced representations. + A base section used to specify the settings of a sampling mesh. + It supports uniformly-spaced meshes and symmetry-reduced representations. """ spacing = Quantity( type=MEnum('Equidistant', 'Logarithmic', 'Tan'), - default='Equidistant', + shape=['dimensionality'], description=""" Identifier for the spacing of the Mesh. Defaults to 'Equidistant' if not defined. It can take the values: @@ -74,19 +88,6 @@ class Mesh(ArchiveSection): """, ) - center = Quantity( - type=MEnum('Gamma-centered', 'Monkhorst-Pack', 'Gamma-offcenter'), - description=""" - Identifier for the center of the Mesh: - - | Name | Description | - | --------- | -------------------------------- | - | `'Gamma-centered'` | Regular mesh is centered around Gamma. No offset. | - | `'Monkhorst-Pack'` | Regular mesh with an offset of half the reciprocal lattice vector. | - | `'Gamma-offcenter'` | Regular mesh with an offset that is neither `'Gamma-centered'`, nor `'Monkhorst-Pack'`. | - """, - ) - quadrature = Quantity( type=MEnum( 'Gauss-Legendre', @@ -105,7 +106,7 @@ class Mesh(ArchiveSection): | `'Clenshaw-Curtis'` | Quadrature rule for integration using Chebyshev polynomials using discrete cosine transformations | | `'Gauss-Hermite'` | Quadrature rule for integration using Hermite polynomials | """, - ) + ) # ! @JosePizarro3 I think that this is separate from the spacing n_points = Quantity( type=np.int32, @@ -118,7 +119,7 @@ class Mesh(ArchiveSection): type=np.int32, default=3, description=""" - Dimensionality of the mesh: 1, 2, or 3. If not defined, it is assumed to be 3. + Dimensionality of the mesh: 1, 2, or 3. Defaults to 3. """, ) @@ -126,9 +127,9 @@ class Mesh(ArchiveSection): type=np.int32, shape=['dimensionality'], description=""" - Amount of mesh point sampling along each axis, i.e. [nx, ny, nz]. + Amount of mesh point sampling along each axis. See `type` for the axes definition. """, - ) + ) # ? @JosePizzaro3: should the mesh also contain its boundary information points = Quantity( type=np.complex128, @@ -143,9 +144,9 @@ class Mesh(ArchiveSection): shape=['n_points'], description=""" The amount of times the same point reappears. A value larger than 1, typically indicates - a symmtery operation that was applied to the `Mesh`. This quantity is equivalent to `weights`: + a symmetry operation that was applied to the `Mesh`. This quantity is equivalent to `weights`: - multiplicities = 1 / weights + multiplicities = n_points * weights """, ) @@ -153,10 +154,10 @@ class Mesh(ArchiveSection): type=np.float64, shape=['n_points'], description=""" - Weight of each point. A value smaller than 1, typically indicates a symmtery operation that was + Weight of each point. A value smaller than 1, typically indicates a symmetry operation that was applied to the mesh. This quantity is equivalent to `multiplicities`: - weights = 1 / multiplicities + weights = multiplicities / n_points """, ) @@ -309,11 +310,27 @@ class KMesh(Mesh): """ label = Quantity( - type=MEnum('k-mesh', 'q-mesh'), + type=MEnum('k-mesh', 'g-mesh', 'q-mesh'), default='k-mesh', description=""" - Label used to identify the `KMesh` with the reciprocal vector used. In linear response, `k` is used for - refering to the wave-vector of electrons, while `q` is used for the scattering effect of the Coulomb potential. + Label used to identify the meaning of the reciprocal grid. + The actual meaning of `k` vs `g` vs `q` is context-dependent, though typically: + - `g` is used for the primitive vectors (typically within the Brillouin zone). + - `k` for a generic reciprocal vector. + - `q` for any momentum change imparted by a scattering event. + """, + ) + + center = Quantity( + type=MEnum('Gamma-centered', 'Monkhorst-Pack', 'Gamma-offcenter'), + description=""" + Identifier for the center of the Mesh: + + | Name | Description | + | --------- | -------------------------------- | + | `'Gamma-centered'` | Regular mesh is centered around Gamma. No offset. | + | `'Monkhorst-Pack'` | Regular mesh with an offset of half the reciprocal lattice vector. | + | `'Gamma-offcenter'` | Regular mesh with an offset that is neither `'Gamma-centered'`, nor `'Monkhorst-Pack'`. | """, ) @@ -841,7 +858,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: class SelfConsistency(NumericalSettings): """ A base section used to define the convergence settings of self-consistent field (SCF) calculation. - It determines the condictions for `is_scf_converged` in `SCFOutputs` (see outputs.py). The convergence + It determines the conditions for `is_scf_converged` in `SCFOutputs` (see outputs.py). The convergence criteria covered are: 1. The number of iterations is smaller than or equal to `n_max_iterations`. @@ -888,17 +905,3 @@ def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwarg def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) - - -class BasisSet(NumericalSettings): - """""" - - # TODO work on this base section (@ndaelman-hu) - - def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwargs): - super().__init__(m_def, m_context, **kwargs) - # Set the name of the section - self.name = self.m_def.name - - def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: - super().normalize(archive, logger) diff --git a/tests/conftest.py b/tests/conftest.py index 097298c2..02c25e41 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -403,3 +403,66 @@ def k_line_path() -> KLinePathSettings: @pytest.fixture(scope='session') def k_space_simulation() -> Simulation: return generate_k_space_simulation() + + +refs_apw = [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.BasisSetContainer', + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.BasisSetContainer', + 'basis_set_components': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWPlaneWaveBasisSet', + 'cutoff_energy': 500.0, + }, + ], + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.BasisSetContainer', + 'basis_set_components': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWPlaneWaveBasisSet', + 'cutoff_energy': 500.0, + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.MuffinTinRegion', + 'species_scope': ['/data/model_system/0/cell/0/atoms_state/0'], + 'radius': 1.0, + 'l_max': 2, + 'l_channels': [ + { + 'name': 0, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'energy_parameter': [0.0], + 'differential_order': [0], + }, + ], + }, + { + 'name': 1, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'energy_parameter': [0.0], + 'differential_order': [0], + }, + ], + }, + { + 'name': 2, + 'orbitals': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.APWOrbital', + 'energy_parameter': [0.0], + 'differential_order': [0], + }, + ], + }, + ], + }, + ], + }, +] diff --git a/tests/test_basis_set.py b/tests/test_basis_set.py new file mode 100644 index 00000000..b1dcb03c --- /dev/null +++ b/tests/test_basis_set.py @@ -0,0 +1,420 @@ +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + import pint + +import numpy as np +import pytest +from nomad.datamodel.datamodel import EntryArchive +from nomad.units import ureg + +from nomad_simulations.schema_packages.atoms_state import AtomsState +from nomad_simulations.schema_packages.basis_set import ( + APWBaseOrbital, + APWLocalOrbital, + APWOrbital, + APWPlaneWaveBasisSet, + AtomCenteredBasisSet, + BasisSetContainer, + MuffinTinRegion, + PlaneWaveBasisSet, + generate_apw, +) +from nomad_simulations.schema_packages.general import Simulation +from nomad_simulations.schema_packages.model_method import BaseModelMethod, ModelMethod +from nomad_simulations.schema_packages.model_system import AtomicCell, ModelSystem +from tests.conftest import refs_apw + +from . import logger + + +@pytest.mark.parametrize( + 'ref_cutoff_radius, cutoff_energy', + [ + (None, None), + (1.823 / ureg.angstrom, 500 * ureg.eV), # reference computed by ChatGPT 4o + ], +) +def test_cutoff( + ref_cutoff_radius: 'pint.Quantity', cutoff_energy: 'pint.Quantity' +) -> None: + """Test the quantitative results when computing certain plane-wave cutoffs.""" + pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff_energy) + cutoff_radius = pw.compute_cutoff_radius(cutoff_energy) + + if cutoff_radius is None: + assert cutoff_radius is ref_cutoff_radius + else: + assert np.isclose( + cutoff_radius.to(ref_cutoff_radius.units).magnitude, + ref_cutoff_radius.magnitude, + atol=1e-3, + ) + + +@pytest.mark.parametrize( + 'mts, ref_mt_r_min', + [ + ([], None), + ([None], None), + ([MuffinTinRegion(radius=1.0 * ureg.angstrom)], 1.0), + ([MuffinTinRegion(radius=r * ureg.angstrom) for r in (1.0, 2.0, 3.0)], 1.0), + ], +) +def test_mt_r_min(mts: list[Optional[MuffinTinRegion]], ref_mt_r_min: float) -> None: + """ + Test the computation of the minimum muffin-tin radius. + """ + bs = BasisSetContainer(basis_set_components=mts) + mt_r_min = bs._find_mt_r_min() + + try: + assert mt_r_min.to('angstrom').magnitude == ref_mt_r_min + except AttributeError: + assert mt_r_min is ref_mt_r_min + + bs.basis_set_components.append(APWPlaneWaveBasisSet(cutoff_energy=500 * ureg('eV'))) + bs.normalize(None, logger) + + try: + assert ( + bs.basis_set_components[-2].mt_r_min.to('angstrom').magnitude + == ref_mt_r_min + ) + except (IndexError, AttributeError): + assert ref_mt_r_min is None + + +@pytest.mark.parametrize( + 'ref_cutoff_fractional, cutoff_energy, mt_radius', + [ + (None, None, None), + (None, 500.0 * ureg.eV, None), + (None, None, 1.0), + (1.823, 500.0 * ureg.eV, 1.0 * ureg.angstrom), + ], +) +def test_cutoff_failure( + ref_cutoff_fractional: float, + cutoff_energy: 'pint.Quantity', + mt_radius: 'pint.Quantity', +) -> None: + """Test modes where `cutoff_fractional` is not computed.""" + pw = APWPlaneWaveBasisSet(cutoff_energy=cutoff_energy if cutoff_energy else None) + if mt_radius is not None: + pw.cutoff_fractional = pw.compute_cutoff_fractional( + pw.compute_cutoff_radius(cutoff_energy), mt_radius + ) + + if ref_cutoff_fractional is None: + assert pw.cutoff_fractional is None + else: + assert np.isclose(pw.cutoff_fractional, ref_cutoff_fractional, atol=1e-3) + + +@pytest.mark.parametrize( + 'ref_index, species_def, cutoff', + [ + (0, {}, None), + (1, {}, 500.0), + ( + 2, + { + '/data/model_system/0/cell/0/atoms_state/0': { + 'r': 1, + 'l_max': 2, + 'orb_d_o': [[0]], + 'orb_param': [[0.0]], + } + }, + 500.0, + ), + ], +) +def test_full_apw( + ref_index: int, species_def: dict[str, dict[str, Any]], cutoff: Optional[float] +) -> None: + """Test the composite structure of APW basis sets.""" + entry = EntryArchive( + data=Simulation( + model_system=[ + ModelSystem( + cell=[AtomicCell(atoms_state=[AtomsState(chemical_symbol='H')])] + ) + ], + model_method=[ModelMethod(numerical_settings=[])], + ) + ) + + numerical_settings = entry.data.model_method[0].numerical_settings + numerical_settings.append(generate_apw(species_def, cutoff=cutoff)) + + # test structure + assert numerical_settings[0].m_to_dict() == refs_apw[ref_index] + + +@pytest.mark.parametrize( + 'ref_n_terms, e, d_o', + [ + (None, None, None), # unset + (0, [], []), # empty + (None, [0.0], []), # logically inconsistent + (1, [0.0], [0]), # apw + (2, 2 * [0.0], [0, 1]), # lapw + ], +) +def test_apw_base_orbital(ref_n_terms: Optional[int], e: list[float], d_o: list[int]): + orb = APWBaseOrbital(energy_parameter=e, differential_order=d_o) + assert orb.get_n_terms() == ref_n_terms + + +@pytest.mark.parametrize('n_terms, ref_n_terms', [(None, 1), (1, 1), (2, None)]) +def test_apw_base_orbital_normalize( + n_terms: Optional[int], ref_n_terms: Optional[int] +) -> None: + orb = APWBaseOrbital( + n_terms=n_terms, + energy_parameter=[0], + differential_order=[1], + ) + orb.normalize(None, logger) + assert orb.n_terms == ref_n_terms + + +@pytest.mark.parametrize( + 'ref_type, do', + [ + (None, None), + (None, []), + (None, [0, 0, 1]), + ('apw', [0]), + ('lapw', [0, 1]), + ('slapw', [0, 2]), + ], +) +def test_apw_orbital(ref_type: Optional[str], do: Optional[int]) -> None: + orb = APWOrbital(differential_order=do) + assert orb.do_to_type(orb.differential_order) == ref_type + + +# ? necessary +@pytest.mark.parametrize( + 'ref_n_terms, e, d_o', + [ + (None, [0.0], []), + (1, [0.0], [0]), + (2, 2 * [0.0], [0, 1]), + (3, 3 * [0.0], [0, 1, 0]), + ], +) +def test_apw_local_orbital( + ref_n_terms: Optional[int], + e: list[float], + d_o: list[int], +) -> None: + orb = APWLocalOrbital( + energy_parameter=e, + differential_order=d_o, + ) + assert orb.get_n_terms() == ref_n_terms + + +@pytest.mark.parametrize( + 'ref_type, ref_mt_counts, ref_l_counts, species_def, cutoff', + [ + ( + None, + [[0, 0, 0, 0, 0]], + [[[0, 0, 0, 0, 0]]], + { + 'H': { + 'r': 1.0, + 'l_max': 0, + 'orb_d_o': [], + 'orb_param': [], + 'lo_d_o': [], + 'lo_param': [], + } + }, + None, + ), + ( + None, + [[1, 0, 0, 0, 0]], + [[[1, 0, 0, 0, 0]]], + { + 'H': { + 'r': 1.0, + 'l_max': 0, + 'orb_d_o': [[0]], + 'orb_param': [[0.0]], + 'lo_d_o': [], + 'lo_param': [], + } + }, + None, + ), + ( + 'APW-like', + [[0, 0, 0, 0, 1]], + [[[0, 0, 0, 0, 1]]], + { + 'H': { + 'r': 1.0, + 'l_max': 0, + 'orb_d_o': [[]], + 'orb_param': [[]], + 'lo_d_o': [], + 'lo_param': [], + } + }, + 500.0, + ), + ( + 'APW', + [[1, 0, 0, 0, 0]], + [[[1, 0, 0, 0, 0]]], + { + 'H': { + 'r': 1.0, + 'l_max': 1, + 'orb_d_o': [[0]], + 'orb_param': [[0.0]], + 'lo_d_o': [], + 'lo_param': [], + } + }, + 500.0, + ), + ( + 'LAPW', + [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0]], + [[[1, 0, 0, 0, 0]], [[0, 1, 0, 0, 0]]], + { + 'H': { + 'r': 1.0, + 'l_max': 0, + 'orb_d_o': [[0]], + 'orb_param': [[0.0]], + 'lo_d_o': [], + 'lo_param': [], + }, + 'O': { + 'r': 2.0, + 'l_max': 0, + 'orb_d_o': [[0, 1]], + 'orb_param': [2 * [0.0]], + 'lo_d_o': [], + 'lo_param': [], + }, + }, + 500.0, + ), + ( + 'SLAPW', + [[1, 0, 0, 0, 0], [0, 1, 1, 0, 0]], + [[[1, 0, 0, 0, 0]], [[0, 1, 1, 0, 0]]], + { + 'H': { + 'r': 1.0, + 'l_max': 0, + 'orb_d_o': [[0]], + 'orb_param': [[0.0]], + 'lo_d_o': [], + 'lo_param': [], + }, + 'O': { + 'r': 2.0, + 'l_max': 2, + 'orb_d_o': [[0, 1], [0, 2]], + 'orb_param': 2 * [2 * [0.0]], + 'lo_d_o': [], + 'lo_param': [], + }, + }, + 500.0, + ), + ], +) +def test_determine_apw( + ref_type: str, + ref_mt_counts: list[list[int]], + ref_l_counts: list[list[list[int]]], + species_def: dict[str, dict[str, Any]], + cutoff: Optional[float], +) -> None: + """Test the L-channel APW structure.""" + ref_keys = ('apw', 'lapw', 'slapw', 'lo', 'other') + bs = generate_apw(species_def, cutoff=cutoff) + + # test from the bottom up + for bsc in bs.basis_set_components: + if isinstance(bsc, MuffinTinRegion): + l_counts = ref_l_counts.pop(0) + for l_channel in bsc.l_channels: + try: + assert l_channel._determine_apw() == dict( + zip(ref_keys, l_counts.pop(0)) + ) + except IndexError: + pass + try: + assert bsc._determine_apw() == dict(zip(ref_keys, ref_mt_counts.pop(0))) + except IndexError: + pass + assert bs._determine_apw() == ref_type + + +def test_quick_step() -> None: + """Test the feasibility of describing a QuickStep basis set.""" + entry = EntryArchive( + data=Simulation( + model_method=[ + ModelMethod( + contributions=[ + BaseModelMethod(name='kinetic'), + BaseModelMethod(name='electron-ion'), + BaseModelMethod(name='hartree'), + ], + numerical_settings=[], + ) + ], + ) + ) + numerical_settings = entry.data.model_method[0].numerical_settings + numerical_settings.append( + BasisSetContainer( + # scope='density', + basis_set_components=[ + AtomCenteredBasisSet( + hamiltonian_scope=[ + entry.data.model_method[0].contributions[0], + entry.data.model_method[0].contributions[1], + ], + ), + PlaneWaveBasisSet( + cutoff_energy=500 * ureg.eV, + hamiltonian_scope=[entry.data.model_method[0].contributions[2]], + ), + ], + ) + ) + + assert numerical_settings[0].m_to_dict() == { + 'm_def': 'nomad_simulations.schema_packages.basis_set.BasisSetContainer', + 'basis_set_components': [ + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.AtomCenteredBasisSet', + 'hamiltonian_scope': [ + '/data/model_method/0/contributions/0', + '/data/model_method/0/contributions/1', + ], + }, + { + 'm_def': 'nomad_simulations.schema_packages.basis_set.PlaneWaveBasisSet', + 'hamiltonian_scope': ['/data/model_method/0/contributions/2'], + 'cutoff_energy': (500.0 * ureg.eV).to('joule').magnitude, + }, + ], + } + # TODO: generate a QuickStep generator in the CP2K plugin