diff --git a/pyproject.toml b/pyproject.toml index da497aba..ae813b12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,11 @@ maintainers = [ { name = "Joseph F. Rudzinski", email = "joseph.rudzinski@physik.hu-berlin.de" } ] license = { file = "LICENSE" } +# dependencies = [ +# "nomad-lab>=1.3.0", +# "matid>=2.0.0.dev2", +# "nomad-simulations@file:///home/bmohr/software/nomad-simulations", +# ] dependencies = [ "nomad-lab>=1.3.0", "matid>=2.0.0.dev2", @@ -40,12 +45,13 @@ dependencies = [ [project.optional-dependencies] dev = [ - "mypy==1.0.1", - "ruff", - "pytest", - "pytest-timeout", - "pytest-cov", - "structlog", + 'mypy==1.0.1', + 'pytest>= 5.3.0, <8', + 'pytest-timeout>=1.4.2', + 'pytest-cov>=2.7.1', + 'ruff>=0.6', + 'structlog>=1.0', + 'typing-extensions>=4.12', ] [tool.uv] diff --git a/src/nomad_simulations/schema_packages/__init__.py b/src/nomad_simulations/schema_packages/__init__.py index 8b730793..78d66557 100644 --- a/src/nomad_simulations/schema_packages/__init__.py +++ b/src/nomad_simulations/schema_packages/__init__.py @@ -31,8 +31,8 @@ class NOMADSimulationsEntryPoint(SchemaPackageEntryPoint): description='Limite of the number of atoms in the unit cell to be treated for the system type classification from MatID to work. This is done to avoid overhead of the package.', ) equal_cell_positions_tolerance: float = Field( - 1e-12, - description='Tolerance (in meters) for the cell positions to be considered equal.', + 12, + description='Decimal order or tolerance (in meters) for comparing cell positions.', ) def load(self): diff --git a/src/nomad_simulations/schema_packages/atoms_state.py b/src/nomad_simulations/schema_packages/atoms_state.py index 32fbdd11..a43a2f5a 100644 --- a/src/nomad_simulations/schema_packages/atoms_state.py +++ b/src/nomad_simulations/schema_packages/atoms_state.py @@ -552,7 +552,16 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: ) -class AtomsState(Entity): +class State(Entity): + """ + A base section to define the state information of the system. + """ + + def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwargs): + super().__init__(m_def, m_context, **kwargs) + + +class AtomsState(State): """ A base section to define each atom state information. """ diff --git a/src/nomad_simulations/schema_packages/general.py b/src/nomad_simulations/schema_packages/general.py index 9a2d48f0..2264e765 100644 --- a/src/nomad_simulations/schema_packages/general.py +++ b/src/nomad_simulations/schema_packages/general.py @@ -1,10 +1,17 @@ -from typing import TYPE_CHECKING +#! TODO: Why is TYPE_CHECKING False? +from typing import TYPE_CHECKING, List, Iterable, Union -if TYPE_CHECKING: - from collections.abc import Callable - - from nomad.datamodel.datamodel import EntryArchive - from structlog.stdlib import BoundLogger +if not TYPE_CHECKING: + from nomad.datamodel.datamodel import ( + EntryArchive, + ) + from nomad.metainfo import ( + Context, + Section, + ) + from structlog.stdlib import ( + BoundLogger, + ) import numpy as np from nomad.config import config @@ -217,7 +224,10 @@ def _set_system_branch_depth( system_parent=system_child, branch_depth=branch_depth + 1 ) - def resolve_composition_formula(self, system_parent: ModelSystem) -> None: + #! Generalize from checks for atomic systems, error with CG input + def resolve_composition_formula( + self, system_parent: ModelSystem, logger: 'BoundLogger' + ) -> None: """Determine and set the composition formula for `system_parent` and all of its descendants. @@ -226,7 +236,7 @@ def resolve_composition_formula(self, system_parent: ModelSystem) -> None: """ def set_composition_formula( - system: ModelSystem, subsystems: list[ModelSystem], atom_labels: list[str] + system: ModelSystem, subsystems: list[ModelSystem], labels: list[str] ) -> None: """Determine the composition formula for `system` based on its `subsystems`. If `system` has no children, the atom_labels are used to determine the formula. @@ -238,13 +248,15 @@ def set_composition_formula( to the atom indices stored in system. """ if not subsystems: - atom_indices = ( - system.atom_indices if system.atom_indices is not None else [] + particle_indices = ( + system.particle_indices + if system.particle_indices is not None + else [] ) subsystem_labels = ( - [np.array(atom_labels)[atom_indices]] - if atom_labels - else ['Unknown' for atom in range(len(atom_indices))] + [np.array(labels)[particle_indices]] + if labels + else ['Unknown' for atom in range(len(particle_indices))] ) else: subsystem_labels = [ @@ -258,7 +270,7 @@ def set_composition_formula( children_names=subsystem_labels ) - def get_composition_recurs(system: ModelSystem, atom_labels: list[str]) -> None: + def get_composition_recurs(system: ModelSystem, labels: list[str]) -> None: """Traverse the system hierarchy downward and set the branch composition for all (sub)systems at each level. @@ -268,22 +280,17 @@ def get_composition_recurs(system: ModelSystem, atom_labels: list[str]) -> None: to the atom indices stored in system. """ subsystems = system.model_system - set_composition_formula( - system=system, subsystems=subsystems, atom_labels=atom_labels - ) + set_composition_formula(system=system, subsystems=subsystems, labels=labels) if subsystems: for subsystem in subsystems: - get_composition_recurs(system=subsystem, atom_labels=atom_labels) - - atoms_state = ( - system_parent.cell[0].atoms_state if system_parent.cell is not None else [] - ) - atom_labels = ( - [atom.chemical_symbol for atom in atoms_state] - if atoms_state is not None - else [] - ) - get_composition_recurs(system=system_parent, atom_labels=atom_labels) + get_composition_recurs(system=subsystem, labels=labels) + + # ! CG: system_parent.cell[0].particles_state instead of atoms_state! + labels = [] + if system_parent.cell is not None: + labels = system_parent.cell[0].get('labels', logger=logger) + + get_composition_recurs(system=system_parent, labels=labels) def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super(Schema, self).normalize(archive, logger) @@ -308,7 +315,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: if is_not_representative(model_system=system_parent, logger=logger): continue - self.resolve_composition_formula(system_parent=system_parent) + self.resolve_composition_formula(system_parent=system_parent, logger=logger) m_package.__init_metainfo__() diff --git a/src/nomad_simulations/schema_packages/model_method.py b/src/nomad_simulations/schema_packages/model_method.py index c7b143ff..7e6cd19d 100644 --- a/src/nomad_simulations/schema_packages/model_method.py +++ b/src/nomad_simulations/schema_packages/model_method.py @@ -544,7 +544,7 @@ def resolve_orbital_references( # If the child is not an "active_atom", the normalization will not run if active_atom.type != 'active_atom': continue - indices = active_atom.atom_indices + indices = active_atom.particle_indices for index in indices: try: active_atoms_state = atoms_state[index] diff --git a/src/nomad_simulations/schema_packages/model_system.py b/src/nomad_simulations/schema_packages/model_system.py index 0555c432..eecc1afb 100644 --- a/src/nomad_simulations/schema_packages/model_system.py +++ b/src/nomad_simulations/schema_packages/model_system.py @@ -1,7 +1,28 @@ +# +# Copyright The NOMAD Authors. +# +# This file is part of NOMAD. See https://nomad-lab.eu for further info. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import re +import sys +from functools import lru_cache +from hashlib import sha1 from typing import TYPE_CHECKING, Optional import ase +from ase.symbols import symbols2numbers import numpy as np from matid import Classifier, SymmetryAnalyzer # pylint: disable=import-error from matid.classification.classifications import ( @@ -22,12 +43,18 @@ from nomad.units import ureg if TYPE_CHECKING: + from collections.abc import Generator + from typing import Any, Callable, Optional + + import pint from nomad.datamodel.datamodel import EntryArchive from nomad.metainfo import Context, Section from structlog.stdlib import BoundLogger from nomad_simulations.schema_packages.atoms_state import AtomsState +from nomad_simulations.schema_packages.particles_state import ParticlesState from nomad_simulations.schema_packages.utils import ( + catch_not_implemented, get_sibling_section, is_not_representative, ) @@ -172,6 +199,7 @@ class GeometricSpace(Entity): """, ) + # TODO: Either generalize this or add logic for different cell types def get_geometric_space_for_atomic_cell(self, logger: 'BoundLogger') -> None: """ Get the real space parameters for the atomic cell using ASE. @@ -200,6 +228,72 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: return +def _check_implemented(func: 'Callable'): + """ + Decorator to restrict the comparison functions to the same class. + """ + + def wrapper(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return func(self, other) + + return wrapper + + +class PartialOrderElement: + def __init__(self, representative_variable): + self.representative_variable = representative_variable + + def __hash__(self): + return self.representative_variable.__hash__() + + @_check_implemented + def __eq__(self, other): + return self.representative_variable == other.representative_variable + + @_check_implemented + def __lt__(self, other): + return False + + @_check_implemented + def __gt__(self, other): + return False + + def __le__(self, other): + return self.__eq__(other) + + def __ge__(self, other): + return self.__eq__(other) + + # __ne__ assumes that usage in a finite set with its comparison definitions + + +class HashedPositions(PartialOrderElement): + # `representative_variable` is a `pint.Quantity` object + + def __hash__(self): + hash_str = sha1( + np.ascontiguousarray( + np.round( + self.representative_variable.to_base_units().magnitude, + decimals=configuration.equal_cell_positions_tolerance, + out=None, + ) + ).tobytes() + ).hexdigest() + return int(hash_str, 16) + + def __eq__(self, other): + """Equality as defined between HashedPositions.""" + if ( + self.representative_variable is None + or other.representative_variable is None + ): + return NotImplemented + return np.allclose(self.representative_variable, other.representative_variable) + + class Cell(GeometricSpace): """ A base section used to specify the cell quantities of a system at a given moment in time. @@ -213,16 +307,19 @@ class Cell(GeometricSpace): """, ) + # TODO: default "unavailable"? type = Quantity( type=MEnum('original', 'primitive', 'conventional'), description=""" Representation type of the cell structure. It might be: - - 'original' as in origanally parsed, + - 'original' as in originally parsed, - 'primitive' as the primitive unit cell, - 'conventional' as the conventional cell used for referencing. """, ) + # ? What does this mean? Number of particles in the cell? + # TODO: improve description n_cell_points = Quantity( type=np.int32, description=""" @@ -235,7 +332,7 @@ class Cell(GeometricSpace): shape=['n_cell_points', 3], unit='meter', description=""" - Positions of all the atoms in Cartesian coordinates. + Positions of all the particles in Cartesian coordinates. """, ) @@ -244,8 +341,8 @@ class Cell(GeometricSpace): shape=['n_cell_points', 3], unit='meter / second', description=""" - Velocities of the atoms. It is the change in cartesian coordinates of the atom position - with time. + Velocities of the particles. It is the change in cartesian coordinates of the + particle position with time. """, ) @@ -254,8 +351,9 @@ class Cell(GeometricSpace): shape=[3, 3], unit='meter', description=""" - Lattice vectors of the simulated cell in Cartesian coordinates. The first index runs - over each lattice vector. The second index runs over the $x, y, z$ Cartesian coordinates. + Lattice vectors of the simulated cell in Cartesian coordinates. The first index + runs over each lattice vector. The second index runs over the $x, y, z$ + Cartesian coordinates. """, ) @@ -278,45 +376,161 @@ class Cell(GeometricSpace): """, ) - def _check_positions(self, positions_1, positions_2) -> list: - # Check that all the `positions`` of `cell_1` match with the ones in `cell_2` - check_positions = [] - for i1, pos1 in enumerate(positions_1): - for i2, pos2 in enumerate(positions_2): - if np.allclose( - pos1, pos2, atol=configuration.equal_cell_positions_tolerance - ): - check_positions.append([i1, i2]) - break - return check_positions - - def is_equal_cell(self, other) -> bool: + def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwargs): + super().__init__(m_def, m_context, **kwargs) + self.logger = None # Initialize logger attribute + + @staticmethod + def _generate_comparer(obj: 'Cell') -> 'Generator[Any, None, None]': + try: + return ((HashedPositions(pos)) for pos in obj.positions) + except AttributeError: + raise NotImplementedError + + @catch_not_implemented + def is_lt_cell(self, other) -> bool: + return set(self._generate_comparer(self)) < set(self._generate_comparer(other)) + + @catch_not_implemented + def is_gt_cell(self, other) -> bool: + return set(self._generate_comparer(self)) > set(self._generate_comparer(other)) + + @catch_not_implemented + def is_le_cell(self, other) -> bool: + return set(self._generate_comparer(self)) <= set(self._generate_comparer(other)) + + @catch_not_implemented + def is_ge_cell(self, other) -> bool: + return set(self._generate_comparer(self)) >= set(self._generate_comparer(other)) + + @catch_not_implemented + def is_equal_cell(self, other) -> bool: # TODO: improve naming + return set(self._generate_comparer(self)) == set(self._generate_comparer(other)) + + def is_ne_cell(self, other) -> bool: + # this does not hold in general, but here we use finite sets + return not self.is_equal_cell(other) + + def get_state(self): + if isinstance(self, AtomicCell): + return self.atoms_state + elif hasattr(self, 'particles_state'): + return self.particles_state + else: + raise AttributeError( + 'The class does not have atoms_state or particles_state' + ) + + def get(self, key, logger=None): + if key == 'state': + return self.get_state() + elif key == 'labels': + if logger is None: + raise ValueError('Logger is not set') + if isinstance(self, AtomicCell): + return self.get_chemical_symbols(logger) + elif isinstance(self, ParticleCell): + return self.get_particle_types(logger) + else: + return None + else: + raise KeyError(f"Key '{key}' not found in Cell") + + def to_ase_atoms(self, logger: 'BoundLogger') -> 'Optional[ase.Atoms]': """ - Check if the cell is equal to an`other` cell by comparing the `positions`. + Generates an ASE Atoms object with the most basic information from the parsed `Cell` + section (labels, periodic_boundary_conditions, positions, and lattice_vectors). + Args: - other: The other cell to compare with. + logger (BoundLogger): The logger to log messages. + Returns: - bool: True if the cells are equal, False otherwise. + (Optional[ase.Atoms]): The ASE Atoms object with the basic information from the `AtomicCell`. + """ + + labels = self.get('labels', logger) + if labels is None: + # ! Check the scope of this message + logger.error( + 'Could not find `Cell.state.labels`.' + 'Using ase functionalities with `X` as labels.' + 'This is normal for non-atomic particles,' + 'but no particle labels will be stored.' + ) + labels = ['X'] * len(self.positions) if self.positions is not None else None + # Initialize ase.Atoms object with labels + # ! We need to make sure that the labels from ase.Atoms are not being used downstream! + else: + try: + symbols2numbers(labels) + except KeyError: + logger.warning( + 'Non chemical symbols in `Cell.state.labels`.' + 'Using ase functionalities with `X` as labels.' + 'This is normal for non-atomic particles.' + ) + labels = ['X'] * len(labels) + ase_atoms = ase.Atoms(symbols=labels) + + # PBC + if self.periodic_boundary_conditions is None: + logger.info( + 'Could not find `Cell.periodic_boundary_conditions`. They will be set to [False, False, False].' + ) + self.periodic_boundary_conditions = [False, False, False] + ase_atoms.set_pbc(pbc=self.periodic_boundary_conditions) + + # Lattice vectors + if self.lattice_vectors is not None: + ase_atoms.set_cell(cell=self.lattice_vectors.to('angstrom').magnitude) + else: + logger.info('Could not find `AtomicCell.lattice_vectors`.') + + # Positions + if self.positions is not None: + if len(self.positions) != len(self.atoms_state): + logger.error( + 'Length of `Cell.positions` does not coincide with the length of the `Cell.`.' + ) + return None + ase_atoms.set_positions( + newpositions=self.positions.to('angstrom').magnitude + ) + else: + logger.warning('Could not find `AtomicCell.positions`.') + return None + + return ase_atoms + + def from_ase_atoms(self, ase_atoms: ase.Atoms, logger: 'BoundLogger') -> None: + """ + Parses the information from an ASE Atoms object to the `Cell` section. + + Args: + ase_atoms (ase.Atoms): The ASE Atoms object to parse. + logger (BoundLogger): The logger to log messages. """ - # TODO implement checks on `lattice_vectors` and other quantities to ensure the equality of primitive cells - if not isinstance(other, Cell): - return False + if isinstance(self, AtomicCell): + # `AtomsState[*].chemical_symbol` + for symbol in ase_atoms.get_chemical_symbols(): + atom_state = AtomsState(chemical_symbol=symbol) + self.atoms_state.append(atom_state) + # TODO: implement for `ParticleCell` - # If the `positions` are empty, return False - if self.positions is None or other.positions is None: - return False + # `periodic_boundary_conditions` + self.periodic_boundary_conditions = ase_atoms.get_pbc() - # The `positions` should have the same length (same number of positions) - if len(self.positions) != len(other.positions): - return False - n_positions = len(self.positions) + # `lattice_vectors` + cell = ase_atoms.get_cell() + self.lattice_vectors = ase.geometry.complete_cell(cell) * ureg('angstrom') - check_positions = self._check_positions( - positions_1=self.positions, positions_2=other.positions - ) - if len(check_positions) != n_positions: - return False - return True + # `positions` + positions = ase_atoms.get_positions() + if ( + not positions.tolist() + ): # ASE assigns a shape=(0, 3) array if no positions are found + return None + self.positions = positions * ureg('angstrom') def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: super().normalize(archive, logger) @@ -361,40 +575,20 @@ def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwarg # Set the name of the section self.name = self.m_def.name - def is_equal_cell(self, other) -> bool: - """ - Check if the atomic cell is equal to an`other` atomic cell by comparing the `positions` and - the `AtomsState[*].chemical_symbol`. - Args: - other: The other atomic cell to compare with. - Returns: - bool: True if the atomic cells are equal, False otherwise. - """ - if not isinstance(other, AtomicCell): - return False - - # Compare positions using the parent sections's `__eq__` method - if not super().is_equal_cell(other=other): - return False - - # Check that the `chemical_symbol` of the atoms in `cell_1` match with the ones in `cell_2` - check_positions = self._check_positions( - positions_1=self.positions, positions_2=other.positions - ) + @staticmethod + def _generate_comparer(obj: 'AtomicCell') -> 'Generator[Any, None, None]': + # presumes `atoms_state` mapping 1-to-1 with `positions` and conserves the order try: - for atom in check_positions: - element_1 = self.atoms_state[atom[0]].chemical_symbol - element_2 = other.atoms_state[atom[1]].chemical_symbol - if element_1 != element_2: - return False - except Exception: - return False - return True + return ( + (HashedPositions(pos), PartialOrderElement(st.chemical_symbol)) + for pos, st in zip(obj.positions, obj.atoms_state) + ) + except AttributeError: + raise NotImplementedError def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]: """ Get the chemical symbols of the atoms in the atomic cell. These are defined on `atoms_state[*].chemical_symbol`. - Args: logger (BoundLogger): The logger to log messages. @@ -412,84 +606,54 @@ def get_chemical_symbols(self, logger: 'BoundLogger') -> list[str]: chemical_symbols.append(atom_state.chemical_symbol) return chemical_symbols - def to_ase_atoms(self, logger: 'BoundLogger') -> Optional[ase.Atoms]: - """ - Generates an ASE Atoms object with the most basic information from the parsed `AtomicCell` - section (labels, periodic_boundary_conditions, positions, and lattice_vectors). + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) - Args: - logger (BoundLogger): The logger to log messages. + # Set the name of the section + self.name = self.m_def.name if self.name is None else self.name - Returns: - (Optional[ase.Atoms]): The ASE Atoms object with the basic information from the `AtomicCell`. - """ - # Initialize ase.Atoms object with labels - atoms_labels = self.get_chemical_symbols(logger=logger) - ase_atoms = ase.Atoms(symbols=atoms_labels) - # PBC - if self.periodic_boundary_conditions is None: - logger.info( - 'Could not find `AtomicCell.periodic_boundary_conditions`. They will be set to [False, False, False].' - ) - self.periodic_boundary_conditions = [False, False, False] - ase_atoms.set_pbc(pbc=self.periodic_boundary_conditions) +# TODO Consider changing name to BeadCell or CGBeadCell, only using "particle" for the more abstract reference to atoms or beads +class ParticleCell(Cell): + """ + A base section used to specify the particle cell information of a system. + """ - # Lattice vectors - if self.lattice_vectors is not None: - ase_atoms.set_cell(cell=self.lattice_vectors.to('angstrom').magnitude) - else: - logger.info('Could not find `AtomicCell.lattice_vectors`.') + particles_state = SubSection(sub_section=ParticlesState.m_def, repeats=True) - # Positions - if self.positions is not None: - if len(self.positions) != len(self.atoms_state): - logger.error( - 'Length of `AtomicCell.positions` does not coincide with the length of the `AtomicCell.atoms_state`.' - ) - return None - ase_atoms.set_positions( - newpositions=self.positions.to('angstrom').magnitude - ) - else: - logger.warning('Could not find `AtomicCell.positions`.') - return None + n_particles = Quantity( + type=np.int32, + description=""" + Number of particles in the particle cell. + """, + ) - return ase_atoms + def __init__(self, m_def: 'Section' = None, m_context: 'Context' = None, **kwargs): + super().__init__(m_def, m_context, **kwargs) + # Set the name of the section + self.name = self.m_def.name - def from_ase_atoms(self, ase_atoms: ase.Atoms, logger: 'BoundLogger') -> None: + def get_particle_types(self, logger: 'BoundLogger') -> list[str]: """ - Parses the information from an ASE Atoms object to the `AtomicCell` section. + Get the chemical symbols of the particle in the particle cell. + These are defined on `particles_state[*].chemical_symbol`. Args: - ase_atoms (ase.Atoms): The ASE Atoms object to parse. logger (BoundLogger): The logger to log messages. - """ - # `AtomsState[*].chemical_symbol` - for symbol in ase_atoms.get_chemical_symbols(): - atom_state = AtomsState(chemical_symbol=symbol) - self.atoms_state.append(atom_state) - # `periodic_boundary_conditions` - self.periodic_boundary_conditions = ase_atoms.get_pbc() - - # `lattice_vectors` - cell = ase_atoms.get_cell() - self.lattice_vectors = ase.geometry.complete_cell(cell) * ureg('angstrom') - - # `positions` - positions = ase_atoms.get_positions() - if ( - not positions.tolist() - ): # ASE assigns a shape=(0, 3) array if no positions are found - return None - self.positions = positions * ureg('angstrom') - - def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: - super().normalize(archive, logger) + Returns: + list: The list of chemical symbols of the particles in the particle cell. + """ + if not self.particles_state: + return [] - # Set the name of the section - self.name = self.m_def.name if self.name is None else self.name + particle_labels = [] + for particle_state in self.particles_state: + if not particle_state.particle_type: + logger.warning('Could not find `ParticlesState[*].particle_type`.') + return [] + particle_labels.append(particle_state.particle_type) + return particle_labels class Symmetry(ArchiveSection): @@ -602,8 +766,11 @@ class Symmetry(ArchiveSection): ) def resolve_analyzed_atomic_cell( - self, symmetry_analyzer: SymmetryAnalyzer, cell_type: str, logger: 'BoundLogger' - ) -> Optional[AtomicCell]: + self, + symmetry_analyzer: 'SymmetryAnalyzer', + cell_type: str, + logger: 'BoundLogger', + ) -> 'Optional[AtomicCell]': """ Resolves the `AtomicCell` section from the `SymmetryAnalyzer` object and the cell_type (primitive or conventional). @@ -647,8 +814,8 @@ def resolve_analyzed_atomic_cell( return atomic_cell def resolve_bulk_symmetry( - self, original_atomic_cell: AtomicCell, logger: 'BoundLogger' - ) -> tuple[Optional[AtomicCell], Optional[AtomicCell]]: + self, original_atomic_cell: 'AtomicCell', logger: 'BoundLogger' + ) -> 'tuple[Optional[AtomicCell], Optional[AtomicCell]]': """ Resolves the symmetry of the material being simulated using MatID and the originally parsed data under original_atomic_cell. It generates two other @@ -860,6 +1027,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: self.m_cache['elemental_composition'] = formula.elemental_composition() +# TODO work on descriptions to reflect the generalization of ModelSystem class ModelSystem(System): """ Model system used as an input for simulating the material. @@ -881,7 +1049,7 @@ class ModelSystem(System): formats. This class nest over itself (with the section proxy in `model_system`) to define different - parent-child system trees. The quantities `branch_label`, `branch_depth`, `atom_indices`, + parent-child system trees. The quantities `branch_label`, `branch_depth`, `particle_indices`, and `bond_list` are used to define the parent-child tree. The normalization is ran in the following order: @@ -1002,27 +1170,29 @@ class ModelSystem(System): """, ) - atom_indices = Quantity( + particle_indices = Quantity( type=np.int32, shape=['*'], description=""" - Indices of the atoms in the child with respect to its parent. Example: + Indices of the atoms or, more generally, particles in the child with respect to its parent. Example: - We have SrTiO3, where `AtomicCell.labels = ['Sr', 'Ti', 'O', 'O', 'O']`. If we create a `model_system` child for the `'Ti'` atom only, then in that child - `ModelSystem.model_system.atom_indices = [1]`. If now we want to refer both to - the `'Ti'` and the last `'O'` atoms, `ModelSystem.model_system.atom_indices = [1, 4]`. + `ModelSystem.model_system.particle_indices = [1]`. If now we want to refer both to + the `'Ti'` and the last `'O'` atoms, `ModelSystem.model_system.particle_indices = [1, 4]`. """, ) - # TODO improve description and add an example using the case in atom_indices + # TODO improve description and add an example using the case in particle_indices bond_list = Quantity( type=np.int32, + shape=['*', 2], description=""" List of pairs of atom indices corresponding to bonds (e.g., as defined by a force field) within this atoms_group. """, ) + # TODO: make this work with non_atomic systems: global_composition_formula of entire system with respect to lower layers composition_formula = Quantity( type=str, description=""" @@ -1115,6 +1285,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: 'Could not find the originally parsed atomic system. `Symmetry` and `ChemicalFormula` extraction is thus not run.' ) return + if self.cell[0].name == 'AtomicCell': self.cell[0].type = 'original' ase_atoms = self.cell[0].to_ase_atoms(logger=logger) @@ -1137,11 +1308,14 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: sec_symmetry = self.m_create(Symmetry) sec_symmetry.normalize(archive, logger) + #! ChemicalFormula calls `ase_atoms = atomic_cell.to_ase_atoms(logger=logger)` and `ase_atoms.get_chemical_formula()` # Creating and normalizing ChemicalFormula section - # TODO add support for fractional formulas (possibly add `AtomicCell.concentrations` for each species) - sec_chemical_formula = self.m_create(ChemicalFormula) - sec_chemical_formula.normalize(archive, logger) - if sec_chemical_formula.m_cache: - self.elemental_composition = sec_chemical_formula.m_cache.get( - 'elemental_composition', [] - ) + if any(cell.name == 'AtomicCell' for cell in self.cell): + # TODO: get_sibling_section() may need to be updated to more specifically search for AtomicCell in ChemicalFormula and Symmetry, in cases where multiple different cells are present + # TODO add support for fractional formulas (possibly add `AtomicCell.concentrations` for each species) + sec_chemical_formula = self.m_create(ChemicalFormula) + sec_chemical_formula.normalize(archive, logger) + if sec_chemical_formula.m_cache: + self.elemental_composition = sec_chemical_formula.m_cache.get( + 'elemental_composition', [] + ) diff --git a/src/nomad_simulations/schema_packages/particles_state.py b/src/nomad_simulations/schema_packages/particles_state.py new file mode 100644 index 00000000..31c0c197 --- /dev/null +++ b/src/nomad_simulations/schema_packages/particles_state.py @@ -0,0 +1,116 @@ +import numbers +from typing import TYPE_CHECKING, Any, Optional, Union + +import ase +import ase.geometry +import numpy as np +import pint + +# from deprecated import deprecated +from nomad.datamodel.data import ArchiveSection +from nomad.datamodel.metainfo.annotations import ELNAnnotation +from nomad.datamodel.metainfo.basesections import Entity +from nomad.metainfo import MEnum, Quantity, SubSection +from nomad.units import ureg + +if TYPE_CHECKING: + from nomad.datamodel.datamodel import EntryArchive + from nomad.metainfo import Context, Section + from structlog.stdlib import BoundLogger + +from nomad_simulations.schema_packages.atoms_state import State + + +# ? How generic (usable for any CG model) vs. Martini-specific do we want to be? +class ParticlesState(State): + """ + A base section to define individual coarse-grained (CG) particle information. + """ + + # ? What do we want to qualify as type identifier? What safety checks do we need? + particle_type = Quantity( + type=str, + description=""" + Symbol(s) describing the CG particle type. Currently, entire particle label is + used for type definition. + """, + ) + + mass = Quantity( + type=np.float64, + unit='kg', + description=""" + Total mass of the particle. + """, + ) + + charge = Quantity( + type=np.float64, + unit='coulomb', + description=""" + Total charge of the particle. + """, + ) + + charge = Quantity( + type=np.float64, + unit='coulomb', + description=""" + Total charge of the particle. + """, + ) + + # Other possible quantities + # diameter: float + # The diameter of each particle. + # Default: 1.0 + # body: int + # The composite body associated with each particle. The value -1 + # indicates no body. + # Default: -1 + # moment_inertia: float + # The moment_inertia of each particle (I_xx, I_yy, I_zz). + # This inertia tensor is diagonal in the body frame of the particle. + # The default value is for point particles. + # Default: 0, 0, 0 + # scaled_positions: list of scaled-positions #! for cell if relevant + # Like positions, but given in units of the unit cell. + # Can not be set at the same time as positions. + # Default: 0, 0, 0 + # orientation: float + # The orientation of each particle. In scalar + vector notation, + # this is (r, a_x, a_y, a_z), where the quaternion is q = r + a_xi + a_yj + a_zk. + # A unit quaternion has the property: sqrt(r^2 + a_x^2 + a_y^2 + a_z^2) = 1. + # Default: 0, 0, 0, 0 + # angmom: float #? for cell or here? + # The angular momentum of each particle as a quaternion. + # Default: 0, 0, 0, 0 + # image: int #! advance PBC stuff would go in cell I guess + # The number of times each particle has wrapped around the box (i_x, i_y, i_z). + # Default: 0, 0, 0 + + # ? What is the purpose exactly of this function? Example? + def resolve_particle_type(self, logger: 'BoundLogger') -> Optional[str]: + """ + Checks if any value is passed as particle label. Converts to string to be used as + type identifier for the CG particle. + + Args: + logger (BoundLogger): The logger to log messages. + + Returns: + (Optional[str]): The resolved `particle type`. + """ + if self.particle_type is not None and self.particle_type.isascii(): + try: + return str(self.particle_type) + except TypeError: + logger.error('The parsed `particle type` can not be read.') + return None + + def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None: + super().normalize(archive, logger) + + # Get particle_type as string, if possible. + if not isinstance(self.particle_type, str): + self.particle_type = self.resolve_particle_type(logger=logger) diff --git a/src/nomad_simulations/schema_packages/physical_property.py b/src/nomad_simulations/schema_packages/physical_property.py index 5bb728bc..4ce63115 100644 --- a/src/nomad_simulations/schema_packages/physical_property.py +++ b/src/nomad_simulations/schema_packages/physical_property.py @@ -14,7 +14,7 @@ SectionProxy, SubSection, ) -from nomad.metainfo.metainfo import Dimension, DirectQuantity, _placeholder_quantity +from nomad.metainfo.metainfo import Dimension if TYPE_CHECKING: from nomad.datamodel.datamodel import EntryArchive @@ -120,7 +120,7 @@ class PhysicalProperty(ArchiveSection): # ! add more examples in the description to improve the understanding of this quantity ) - rank = DirectQuantity( + rank = Quantity( type=Dimension, shape=['0..*'], default=[], @@ -137,7 +137,7 @@ class PhysicalProperty(ArchiveSection): variables = SubSection(sub_section=Variables.m_def, repeats=True) # * `value` must be overwritten in the derived classes defining its type, unit, and description - value: Quantity = _placeholder_quantity + value: Quantity = None entity_ref = Quantity( type=Entity, diff --git a/src/nomad_simulations/schema_packages/utils/__init__.py b/src/nomad_simulations/schema_packages/utils/__init__.py index 52d9ca22..f9945a34 100644 --- a/src/nomad_simulations/schema_packages/utils/__init__.py +++ b/src/nomad_simulations/schema_packages/utils/__init__.py @@ -1,5 +1,6 @@ from .utils import ( RussellSaundersState, + catch_not_implemented, get_composition, get_sibling_section, get_variables, diff --git a/src/nomad_simulations/schema_packages/utils/utils.py b/src/nomad_simulations/schema_packages/utils/utils.py index 1d40aa4a..6483c428 100644 --- a/src/nomad_simulations/schema_packages/utils/utils.py +++ b/src/nomad_simulations/schema_packages/utils/utils.py @@ -5,7 +5,7 @@ from nomad.config import config if TYPE_CHECKING: - from typing import Optional + from typing import Callable, Optional from nomad.datamodel.data import ArchiveSection from structlog.stdlib import BoundLogger @@ -48,6 +48,7 @@ def get_sibling_section( if not sibling_section_name: logger.warning('The sibling_section_name is empty.') return None + sibling_section = section.m_xpath(f'm_parent.{sibling_section_name}', dict=False) # If the sibling_section is a list, return the element `index_sibling` of that list if isinstance(sibling_section, list): @@ -154,3 +155,19 @@ def get_composition(children_names: 'list[str]') -> str: children_count_tup = np.unique(children_names, return_counts=True) formula = ''.join([f'{name}({count})' for name, count in zip(*children_count_tup)]) return formula if formula else None + + +def catch_not_implemented(func: 'Callable') -> 'Callable': + """ + Decorator to default comparison functions outside the same class to `False`. + """ + + def wrapper(self, other) -> bool: + if not isinstance(other, self.__class__): + return False # ? should this throw an error instead? + try: + return func(self, other) + except (TypeError, NotImplementedError): + return False + + return wrapper diff --git a/tests/test_general.py b/tests/test_general.py index a693f33c..621db4c9 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -87,7 +87,7 @@ def get_flat_depths( assert value == result @pytest.mark.parametrize( - 'is_representative, has_atom_indices, mol_label_list, n_mol_list, atom_labels_list, composition_formula_list, custom_formulas', + 'is_representative, has_particle_indices, mol_label_list, n_mol_list, atom_labels_list, composition_formula_list, custom_formulas', [ ( True, @@ -168,7 +168,7 @@ def get_flat_depths( def test_system_hierarchy_for_molecules( self, is_representative: bool, - has_atom_indices: bool, + has_particle_indices: bool, mol_label_list: list[str], n_mol_list: list[int], atom_labels_list: list[str], @@ -181,8 +181,8 @@ def test_system_hierarchy_for_molecules( Args: is_representative (bool): Specifies if branch_depth = 0 is representative or not. If not representative, the composition formulas should not be generated. - has_atom_indices (bool): Specifies if the atom_indices should be populated during parsing. - Without atom_indices, the composition formulas for the deepest level of the hierarchy + has_particle_indices (bool): Specifies if the particle_indices should be populated during parsing. + Without particle_indices, the composition formulas for the deepest level of the hierarchy should not be populated. mol_label_list (list[str]): Molecule types for generating the hierarchy. n_mol_list (list[int]): Number of molecules for each molecule type. Should be same @@ -212,15 +212,15 @@ def test_system_hierarchy_for_molecules( ctr_comp = 1 atomic_cell = AtomicCell() model_system.cell.append(atomic_cell) - if has_atom_indices: - model_system.atom_indices = [] + if has_particle_indices: + model_system.particle_indices = [] for mol_label, n_mol, atom_labels in zip( mol_label_list, n_mol_list, atom_labels_list ): # Create a branch in the hierarchy for this molecule type model_system_mol_group = ModelSystem() - if has_atom_indices: - model_system_mol_group.atom_indices = [] + if has_particle_indices: + model_system_mol_group.particle_indices = [] model_system_mol_group.branch_label = ( f'group_{mol_label}' if mol_label is not None else None ) @@ -241,14 +241,14 @@ def test_system_hierarchy_for_molecules( AtomsState(chemical_symbol=atom_label) ) n_atoms = len(atomic_cell.atoms_state) - atom_indices = np.arange(n_atoms - len(atom_labels), n_atoms) - if has_atom_indices: - model_system_mol.atom_indices = atom_indices - model_system_mol_group.atom_indices = np.append( - model_system_mol_group.atom_indices, atom_indices + particle_indices = np.arange(n_atoms - len(atom_labels), n_atoms) + if has_particle_indices: + model_system_mol.particle_indices = particle_indices + model_system_mol_group.particle_indices = np.append( + model_system_mol_group.particle_indices, particle_indices ) - model_system.atom_indices = np.append( - model_system.atom_indices, atom_indices + model_system.particle_indices = np.append( + model_system.particle_indices, particle_indices ) simulation.normalize(EntryArchive(), logger) diff --git a/tests/test_model_method.py b/tests/test_model_method.py index 6f6e6393..d352b858 100644 --- a/tests/test_model_method.py +++ b/tests/test_model_method.py @@ -79,7 +79,7 @@ def test_resolve_type(self, tb_section: TB, result: Optional[str]): is_representative=True, cell=[AtomicCell(atoms_state=[AtomsState()])], model_system=[ - ModelSystem(type='active_atom', atom_indices=[2]) + ModelSystem(type='active_atom', particle_indices=[2]) ], ) ], @@ -93,7 +93,7 @@ def test_resolve_type(self, tb_section: TB, result: Optional[str]): is_representative=True, cell=[AtomicCell(atoms_state=[AtomsState(orbitals_state=[])])], model_system=[ - ModelSystem(type='active_atom', atom_indices=[0]) + ModelSystem(type='active_atom', particle_indices=[0]) ], ) ], @@ -117,7 +117,7 @@ def test_resolve_type(self, tb_section: TB, result: Optional[str]): ) ], model_system=[ - ModelSystem(type='active_atom', atom_indices=[0]) + ModelSystem(type='active_atom', particle_indices=[0]) ], ) ], @@ -207,7 +207,7 @@ def test_resolve_orbital_references( is_representative=True, cell=[AtomicCell(atoms_state=[AtomsState()])], model_system=[ - ModelSystem(type='active_atom', atom_indices=[2]) + ModelSystem(type='active_atom', particle_indices=[2]) ], ) ], @@ -222,7 +222,7 @@ def test_resolve_orbital_references( is_representative=True, cell=[AtomicCell(atoms_state=[AtomsState(orbitals_state=[])])], model_system=[ - ModelSystem(type='active_atom', atom_indices=[0]) + ModelSystem(type='active_atom', particle_indices=[0]) ], ) ], @@ -247,7 +247,7 @@ def test_resolve_orbital_references( ) ], model_system=[ - ModelSystem(type='active_atom', atom_indices=[0]) + ModelSystem(type='active_atom', particle_indices=[0]) ], ) ], @@ -272,7 +272,7 @@ def test_resolve_orbital_references( ) ], model_system=[ - ModelSystem(type='active_atom', atom_indices=[0]) + ModelSystem(type='active_atom', particle_indices=[0]) ], ) ], diff --git a/tests/test_model_system.py b/tests/test_model_system.py index f334da23..088ecc6b 100644 --- a/tests/test_model_system.py +++ b/tests/test_model_system.py @@ -18,96 +18,104 @@ from .conftest import generate_atomic_cell -class TestCell: +class TestAtomicCell: """ - Test the `Cell` section defined in model_system.py + Test the `AtomicCell`, `Cell` and `GeometricSpace` classes defined in model_system.py """ @pytest.mark.parametrize( 'cell_1, cell_2, result', [ - (Cell(), None, False), # one cell is None - (Cell(), Cell(), False), # both cells are empty + (Cell(), None, {'lt': False, 'gt': False, 'eq': False}), # one cell is None + # (Cell(), Cell(), False), # both cells are empty + # ( + # Cell(positions=[[1, 0, 0]]), + # Cell(), + # False, + # ), # one cell has positions, the other is empty ( Cell(positions=[[1, 0, 0]]), - Cell(), - False, - ), # one cell has positions, the other is empty + Cell(positions=[[2, 0, 0]]), + {'lt': False, 'gt': False, 'eq': False}, + ), # position vectors are treated as the fundamental set elements ( Cell(positions=[[1, 0, 0], [0, 1, 0]]), Cell(positions=[[1, 0, 0]]), - False, - ), # length mismatch - ( - Cell(positions=[[1, 0, 0], [0, 1, 0]]), - Cell(positions=[[1, 0, 0], [0, -1, 0]]), - False, - ), # different positions - ( - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - True, - ), # same ordered positions - ( - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - Cell(positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]]), - True, - ), # different ordered positions but same cell - ], - ) - def test_is_equal_cell(self, cell_1: Cell, cell_2: Cell, result: bool): - """ - Test the `is_equal_cell` methods of `Cell`. - """ - assert cell_1.is_equal_cell(other=cell_2) == result - - -class TestAtomicCell: - """ - Test the `AtomicCell`, `Cell` and `GeometricSpace` classes defined in model_system.py - """ - - @pytest.mark.parametrize( - 'cell_1, cell_2, result', - [ - (Cell(), None, False), # one cell is None - (Cell(), Cell(), False), # both cells are empty + {'lt': False, 'gt': True, 'eq': False}, + ), # one is a subset of the other ( Cell(positions=[[1, 0, 0]]), - Cell(), - False, - ), # one cell has positions, the other is empty - ( Cell(positions=[[1, 0, 0], [0, 1, 0]]), - Cell(positions=[[1, 0, 0]]), - False, - ), # length mismatch + {'lt': True, 'gt': False, 'eq': False}, + ), # one is a subset of the other ( Cell(positions=[[1, 0, 0], [0, 1, 0]]), Cell(positions=[[1, 0, 0], [0, -1, 0]]), - False, + {'lt': False, 'gt': False, 'eq': False}, ), # different positions ( Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - True, + {'lt': False, 'gt': False, 'eq': True}, ), # same ordered positions ( Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), Cell(positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]]), - True, + {'lt': False, 'gt': False, 'eq': True}, ), # different ordered positions but same cell + # ( + # AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + # Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + # False, + # ), # one atomic cell and another cell (missing chemical symbols) + # ( + # AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + # AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + # False, + # ), # missing chemical symbols + # ND: the comparison will now return an error here + # handling a case that should be resolved by the normalizer falls outside its scope ( - AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - Cell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - False, - ), # one atomic cell and another cell (missing chemical symbols) + AtomicCell( + positions=[[1, 0, 0]], + atoms_state=[ + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + ], + ), + {'lt': False, 'gt': False, 'eq': False}, + ), # chemical symbols are treated as the fundamental set elements ( - AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - AtomicCell(positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]]), - False, - ), # missing chemical symbols + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + ], + ), + {'lt': False, 'gt': True, 'eq': False}, + ), # one is a subset of the other ( + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + ], + ), AtomicCell( positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], atoms_state=[ @@ -116,6 +124,16 @@ class TestAtomicCell: AtomsState(chemical_symbol='O'), ], ), + {'lt': True, 'gt': False, 'eq': False}, + ), # one is a subset of the other + ( + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), AtomicCell( positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], atoms_state=[ @@ -124,7 +142,26 @@ class TestAtomicCell: AtomsState(chemical_symbol='O'), ], ), - True, + {'lt': False, 'gt': False, 'eq': False}, + ), + ( + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + {'lt': False, 'gt': False, 'eq': True}, ), # same ordered positions and chemical symbols ( AtomicCell( @@ -143,7 +180,7 @@ class TestAtomicCell: AtomsState(chemical_symbol='O'), ], ), - False, + {'lt': False, 'gt': False, 'eq': False}, ), # same ordered positions but different chemical symbols ( AtomicCell( @@ -162,38 +199,41 @@ class TestAtomicCell: AtomsState(chemical_symbol='H'), ], ), - True, - ), # different ordered positions but same chemical symbols - ], - ) - def test_is_equal_cell(self, cell_1: Cell, cell_2: Cell, result: bool): - """ - Test the `is_equal_cell` methods of `AtomicCell`. - """ - assert cell_1.is_equal_cell(other=cell_2) == result - - @pytest.mark.parametrize( - 'atomic_cell, result', - [ - (AtomicCell(), []), - (AtomicCell(atoms_state=[AtomsState(chemical_symbol='H')]), ['H']), + {'lt': False, 'gt': False, 'eq': True}, + ), # same position-symbol map, different overall order ( AtomicCell( + positions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]], + atoms_state=[ + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='H'), + AtomsState(chemical_symbol='O'), + ], + ), + AtomicCell( + positions=[[1, 0, 0], [0, 0, 1], [0, 1, 0]], atoms_state=[ AtomsState(chemical_symbol='H'), - AtomsState(chemical_symbol='Fe'), + AtomsState(chemical_symbol='H'), AtomsState(chemical_symbol='O'), - ] + ], ), - ['H', 'Fe', 'O'], - ), + {'lt': False, 'gt': False, 'eq': False}, + ), # different position-symbol map ], ) - def test_get_chemical_symbols(self, atomic_cell: AtomicCell, result: list[str]): + def test_partial_order( + self, cell_1: 'Cell', cell_2: 'Cell', result: dict[str, bool] + ): """ - Test the `get_chemical_symbols` method of `AtomicCell`. + Test the comparison operators of `Cell` and `AtomicCell`. """ - assert atomic_cell.get_chemical_symbols(logger=logger) == result + assert cell_1.is_lt_cell(cell_2) == result['lt'] + assert cell_1.is_gt_cell(cell_2) == result['gt'] + assert cell_1.is_le_cell(cell_2) == (result['lt'] or result['eq']) + assert cell_1.is_ge_cell(cell_2) == (result['gt'] or result['eq']) + assert cell_1.is_equal_cell(cell_2) == result['eq'] + assert cell_1.is_ne_cell(cell_2) == (not result['eq']) @pytest.mark.parametrize( 'chemical_symbols, atomic_numbers, formula, lattice_vectors, positions, periodic_boundary_conditions',