Skip to content

Commit

Permalink
Fix parsing with example from Vikrant
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Sep 27, 2024
1 parent 7786e9e commit bbe5357
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 62 deletions.
4 changes: 4 additions & 0 deletions src/nomad_parser_wannier90/parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ class SimulationParserEntryPoint(ParserEntryPoint):
Order of execution of parser with respect to other parsers.
""",
)
equal_cell_positions_tolerance: float = Field(
1e-2,
description='Tolerance (in angstroms) for the cell positions to be considered equal.',
)

def load(self):
from nomad.parsing import MatchingParserInterface
Expand Down
196 changes: 134 additions & 62 deletions src/nomad_parser_wannier90/parsers/win_parser.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,49 @@
from functools import wraps
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
from structlog.stdlib import BoundLogger

import re

import numpy as np
from nomad.config import config
from nomad.parsing.file_parser import Quantity, TextParser
from nomad_simulations.schema_packages.atoms_state import OrbitalsState
from nomad_simulations.schema_packages.model_system import AtomicCell, ModelSystem

configuration = config.get_plugin_entry_point(
'nomad_parser_wannier90.parsers:parser_entry_point'
)


def validate_atomic_cell(func):
@wraps(func)
def wrapper(
self,
position: list[float],
atomic_cell: Optional['AtomicCell'],
units: str,
*args,
**kwargs,
):
if (
atomic_cell is None
or atomic_cell.atoms_state is None
or len(atomic_cell.atoms_state) == 0
):
print('Invalid atomic cell: either None or contains no `AtomsState`.')
return [], [] # Return an empty tuple (or handle the error differently)
return func(self, position, atomic_cell, units, *args, **kwargs)

return wrapper


class WInParser(TextParser):
def init_quantities(self):
def str_proj_to_list(val_in):
# To avoid inconsistent regex that can contain or not spaces
val_n = [x for x in val_in.split('\n') if x]
val_n = [re.sub(r'\s.*', '', x) for x in val_in.split('\n') if x]
return [v.strip('[]').replace(' ', '').split(':') for v in val_n]

self._quantities = [
Expand Down Expand Up @@ -74,75 +104,111 @@ def __init__(self, win_file: str = ''):
(3, 6): ('f', 'x(x^2-3y^2)'),
(3, 7): ('f', 'y(3x^2-y^2)'),
}
# Only angular momentum [l] (degenerate in mr)
self._wannier_l_orbital_map = {
0: 's',
1: 'p',
2: 'd',
3: 'f',
}

def _convert_positions_to_symbols(
self, atomic_cell: AtomicCell, units: str, positions: list[float]
) -> Optional[str]:
@validate_atomic_cell
def _convert_position(
self, position: list[float], atomic_cell: AtomicCell, units: str
) -> tuple[list, list]:
"""
Convert the atom `positions` in fractional or cartesian coordinates to the atom `chemical_symbols`.
Args:
atomic_cell (AtomicCell): The `AtomicCell` section to which `positions` are extracted
position (list[float]): The position in fractional or cartesian coordinates.
atomic_cell (AtomicCell): The `AtomicCell` section from which `positions` are extracted
units (str): The units in which the positions are defined.
positions (list[float]): The positions in fractional or cartesian coordinates to be converted to `chemical_symbols`.
Returns:
(Optional[str]): The `chemical_symbols` of the atom at the position `val`.
tuple[list, list]: The indices and symbols at which the `position` coincides with the `AtomicCell.positions[*]`.
"""
for cell_position in atomic_cell.positions.to(units):
if np.array_equal(positions, cell_position.magnitude):
index = atomic_cell.positions.magnitude.tolist().index(
cell_position.magnitude.tolist()
)
return atomic_cell.atoms_state[index].chemical_symbol
return None
indices = []
symbols = []
for index, cell_position in enumerate(atomic_cell.positions.to(units)):
if np.allclose(
position,
cell_position.magnitude,
configuration.equal_cell_positions_tolerance,
):
indices.append(index)
symbols.append(atomic_cell.atoms_state[index].chemical_symbol)
return indices, symbols

def _get_f_information(
self, atom: str, atomic_cell: AtomicCell, units: str
) -> tuple[list, list]:
position = [float(x) for x in atom.replace('f=', '').split(',')]
position = np.dot(position, atomic_cell.lattice_vectors.magnitude)
return self._convert_position(
position=position, atomic_cell=atomic_cell, units=units
)

def _get_c_information(
self, atom: str, atomic_cell: AtomicCell, units: str
) -> tuple[list, list]:
position = [float(x) for x in atom.replace('c=', '').split(',')]
return self._convert_position(
position=position, atomic_cell=atomic_cell, units=units
)

def parse_child_atom_indices(
def get_branch_label_and_atom_indices(
self,
atom: Union[str, int],
atomic_cell: AtomicCell,
units: str,
) -> tuple[Optional[str], list[int]]:
"""
Parse the atom indices for the child model system.
Gets the branch label and the atom indices for the child model system.
Args:
atom (Union[str, int]): The atom string containing the positions information. In some older version,
this can be an integer index pointing to the atom.
this can be an integer index pointing to the atom (which is very buggy).
atomic_cell (AtomicCell): The `AtomicCell` section where `positions` are stored
units (str): The units in which the positions are defined.
Returns:
(tuple[str, list[int]]): The `branch_label` and `atom_indices` for the child model system.
tuple[str, list[int]]: The `branch_label` and `atom_indices` for the child model system.
"""
# Initial check for bugs when `atom` is an integer
if isinstance(atom, int):
return '', [atom]

# 3 different cases to define in `win`
symbols = ''
indices = []
# If the atom is not a chemical element, we use the `_convert_position` method resolution for it, joining the `symbols` into a long string
if atom.startswith('f='): # fractional coordinates
positions = [float(x) for x in atom.replace('f=', '').split(',')]
positions = np.dot(positions, atomic_cell.lattice_vectors.magnitude)
sites = self._convert_positions_to_symbols(
atomic_cell=atomic_cell, units=units, positions=positions
indices, symbols = self._get_f_information(
atom=atom, atomic_cell=atomic_cell, units=units
)
elif atom.startswith('c='): # cartesian coordinates
positions = [float(x) for x in atom.replace('c=', '').split(',')]
sites = self._convert_positions_to_symbols(
atomic_cell=atomic_cell, units=units, positions=positions
indices, symbols = self._get_c_information(
atom=atom, atomic_cell=atomic_cell, units=units
)
# Otherwise, if the atom chemical symbol is directly specified, we store all the `atom_indices` coinciding with this label
else: # atom label directly specified
sites = atom
atom_indices = np.where(
[
atom_state.chemical_symbol == atom
for atom_state in atomic_cell.atoms_state
]
)[0].tolist()
return atom, atom_indices

# Find the `atom_indices` which coincide with the `sites`
branch_label = sites
atom_indices = np.where(
[atom.chemical_symbol == branch_label for atom in atomic_cell.atoms_state]
)[0].tolist()
return branch_label, atom_indices
branch_label = ''.join(symbols)
return branch_label, indices

def populate_orbitals_state(
self,
projection: list[str],
model_system_child: ModelSystem,
atomic_cell: AtomicCell,
units: str,
logger: 'BoundLogger',
) -> None:
"""
Expand All @@ -154,42 +220,47 @@ def populate_orbitals_state(
atomic_cell (AtomicCell): The `AtomicCell` section where `positions` are stored.
logger (BoundLogger): The logger to log messages.
"""
# Bug when `atom` is an integer
atom = projection[0]
for atom_index in model_system_child.atom_indices:
if atomic_cell.atoms_state is None or len(atomic_cell.atoms_state) == 0:
logger.warning(
'Could not extract the `AtomicCell.atoms_state` sections.'
)
continue
atom_state = atomic_cell.atoms_state[atom_index]
if isinstance(atom, int):
return '', [atom]

# To avoid issues when `atom` is an integer
if isinstance(atom, str) and atom != atom_state.chemical_symbol:
continue
# Initial check for the `atom` and their `indices`
indices = []
if atom.startswith('f='):
indices, _ = self._get_f_information(
atom=atom, atomic_cell=atomic_cell, units=units
)
elif atom.startswith('c='):
indices, _ = self._get_c_information(
atom=atom, atomic_cell=atomic_cell, units=units
)
else:
indices = model_system_child.atom_indices
if not indices:
logger.warning('Could not extract the `AtomicCell.atoms_state` sections.')
return None

# Try to get the orbitals information
try:
orbitals = projection[1].split(';')
angular_momentum = None
for orb in orbitals:
orbital_state = OrbitalsState()
if orb.startswith('l='): # using angular momentum numbers
lmom = int(orb.split(',mr')[0].replace('l=', '').split(',')[0])
orbitals = projection[1].split(';')
for atom_index in model_system_child.atom_indices:
atom_state = atomic_cell.atoms_state[atom_index]
for orb in orbitals:
orbital_state = OrbitalsState()
if orb.startswith('l='): # using angular momentum numbers
lmom = int(orb.split(',mr')[0].replace('l=', '').split(',')[0])
if len(orb.split(',mr')) > 1:
mrmom = int(orb.split(',mr')[-1].replace('=', '').split(',')[0])
angular_momentum = self._wannier_orbital_numbers_map.get(
(lmom, mrmom)
)
else: # ang mom label directly specified
angular_momentum = self._wannier_orbital_symbols_map.get(orb)
(
orbital_state.l_quantum_symbol,
orbital_state.ml_quantum_symbol,
) = angular_momentum
atom_state.orbitals_state.append(orbital_state)
except Exception:
logger.warning('Projected orbital labels not found from win.')
return None
return None
else:
angular_momentum = (self._wannier_l_orbital_map.get(lmom), None)
else: # angular_momentum label directly specified
angular_momentum = self._wannier_orbital_symbols_map.get(orb)
orbital_state.l_quantum_symbol, orbital_state.ml_quantum_symbol = (
angular_momentum
)
atom_state.orbitals_state.append(orbital_state)

def parse_child_model_systems(
self,
Expand Down Expand Up @@ -238,7 +309,7 @@ def parse_child_model_systems(
projection = projections[nat]
atom = projection[0]
try:
branch_label, atom_indices = self.parse_child_atom_indices(
branch_label, atom_indices = self.get_branch_label_and_atom_indices(
atom=atom, atomic_cell=atomic_cell, units=wannier90_units
)
model_system_child.branch_label = branch_label
Expand All @@ -255,6 +326,7 @@ def parse_child_model_systems(
model_system_child=model_system_child,
atomic_cell=atomic_cell,
logger=logger,
units=wannier90_units,
)
model_system_childs.append(model_system_child)

Expand Down

0 comments on commit bbe5357

Please sign in to comment.