Skip to content

Commit

Permalink
Adding support for aiida-atomistic
Browse files Browse the repository at this point in the history
- in pseudo and cutoff we introduce the naming `LegacyStructureData` for th `orm.StructureData`.
- two additional tests, triggered only when we have atomistic installed (`HA` is `True`).
  • Loading branch information
mikibonacci committed Nov 22, 2024
1 parent 429f7a3 commit f313d49
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 10 deletions.
24 changes: 17 additions & 7 deletions src/aiida_pseudo/groups/family/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
from aiida.common import exceptions
from aiida.common.lang import classproperty, type_check
from aiida.orm import Group, QueryBuilder
from aiida.orm.nodes.data.structure import has_atomistic
from aiida.plugins import DataFactory

from aiida_pseudo.data.pseudo import PseudoPotentialData

__all__ = ('PseudoPotentialFamily',)

StructureData = DataFactory('core.structure')
LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name

#
HA = has_atomistic()
if HA:
StructureData = DataFactory('atomistic.structure')


class PseudoPotentialFamily(Group):
Expand Down Expand Up @@ -308,12 +314,12 @@ def get_pseudos(
self,
*,
elements: Optional[Union[List[str], Tuple[str]]] = None,
structure: StructureData = None,
) -> Mapping[str, StructureData]:
structure: Union[StructureData, LegacyStructureData] if HA else Union[LegacyStructureData] = None,
) -> Mapping[str, Union[StructureData, LegacyStructureData] if HA else Union[LegacyStructureData]]:
"""Return the mapping of kind names on pseudo potential data nodes for the given list of elements or structure.
:param elements: list of element symbols.
:param structure: the ``StructureData`` node.
:param structure: the ``StructureData`` or ``LegacyStructureData`` node.
:return: dictionary mapping the kind names of a structure on the corresponding pseudo potential data nodes.
:raises ValueError: if the family does not contain a pseudo for any of the elements of the given structure.
"""
Expand All @@ -323,11 +329,15 @@ def get_pseudos(
if elements is None and structure is None:
raise ValueError('have to specify one of the keyword arguments `elements` and `structure`.')

if elements is not None and not isinstance(elements, (list, tuple)) and not isinstance(elements, StructureData):
if elements is not None and not isinstance(elements, (list, tuple)):
raise ValueError('elements should be a list or tuple of symbols.')

if structure is not None and not isinstance(structure, StructureData):
raise ValueError('structure should be a `StructureData` instance.')
if structure is not None and not (
isinstance(structure, (LegacyStructureData, StructureData if HA else LegacyStructureData))
):
raise ValueError(
f'structure should be a `StructureData` or `LegacyStructureData` instance, not {type(structure)}.'
)

if structure is not None:
return {kind.name: self.get_pseudo(kind.symbol) for kind in structure.kinds}
Expand Down
10 changes: 8 additions & 2 deletions src/aiida_pseudo/groups/mixins/cutoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
from typing import Optional

from aiida.common.lang import type_check
from aiida.orm.nodes.data.structure import has_atomistic
from aiida.plugins import DataFactory

from aiida_pseudo.common.units import U

StructureData = DataFactory('core.structure')
LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name

#
HA = has_atomistic()
if HA:
StructureData = DataFactory('atomistic.structure')

__all__ = ('RecommendedCutoffMixin',)

Expand Down Expand Up @@ -278,7 +284,7 @@ def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=N
raise ValueError('at least one and only one of `elements` or `structure` should be defined')

type_check(elements, (tuple, str), allow_none=True)
type_check(structure, StructureData, allow_none=True)
type_check(structure, (LegacyStructureData, StructureData) if HA else (LegacyStructureData), allow_none=True)

if unit is not None:
self.validate_cutoffs_unit(unit)
Expand Down
43 changes: 42 additions & 1 deletion tests/groups/family/test_pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import pytest
from aiida.common import exceptions
from aiida.orm import QueryBuilder
from aiida.orm.nodes.data.structure import has_atomistic
from aiida_pseudo.data.pseudo import PseudoPotentialData
from aiida_pseudo.groups.family.pseudo import PseudoPotentialFamily

skip_atomistic = pytest.mark.skipif(not has_atomistic(), reason='Unable to import aiida-atomistic')


def test_type_string():
"""Verify the `_type_string` class attribute is correctly set to the corresponding entry point name."""
Expand Down Expand Up @@ -408,7 +411,7 @@ def test_get_pseudos_raise(get_pseudo_family, generate_structure):
with pytest.raises(ValueError, match='elements should be a list or tuple of symbols.'):
family.get_pseudos(elements={'He', 'Ar'})

with pytest.raises(ValueError, match='structure should be a `StructureData` instance.'):
with pytest.raises(ValueError, match='structure should be a `StructureData` or `LegacyStructureData` instance'):
family.get_pseudos(structure={'He', 'Ar'})

with pytest.raises(ValueError, match=r'family `.*` does not contain pseudo for element `.*`'):
Expand Down Expand Up @@ -454,3 +457,41 @@ def test_get_pseudos_structure_kinds(get_pseudo_family, generate_structure):
assert isinstance(pseudos, dict)
for element in elements:
assert isinstance(pseudos[element], PseudoPotentialData)


@skip_atomistic
@pytest.mark.usefixtures('aiida_profile_clean')
def test_get_pseudos_atomsitic_structure(get_pseudo_family, generate_structure):
"""
Test the `PseudoPotentialFamily.get_pseudos` method when passing
an aiida-atomistic ``StructureData`` instance.
"""

elements = ('Ar', 'He', 'Ne')
orm_structure = generate_structure(elements)
structure = orm_structure.to_atomistic()
family = get_pseudo_family(elements=elements)

pseudos = family.get_pseudos(structure=structure)
assert isinstance(pseudos, dict)
for element in elements:
assert isinstance(pseudos[element], PseudoPotentialData)


@skip_atomistic
@pytest.mark.usefixtures('aiida_profile_clean')
def test_get_pseudos_atomistic_structure_kinds(get_pseudo_family, generate_structure):
"""
Test the `PseudoPotentialFamily.get_pseudos` for
an aiida-atomistic ``StructureData`` with kind names including digits.
"""

elements = ('Ar1', 'Ar2')
orm_structure = generate_structure(elements)
structure = orm_structure.to_atomistic()
family = get_pseudo_family(elements=elements)

pseudos = family.get_pseudos(structure=structure)
assert isinstance(pseudos, dict)
for element in elements:
assert isinstance(pseudos[element], PseudoPotentialData)

0 comments on commit f313d49

Please sign in to comment.