Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for aiida-atomistic #178

Merged
merged 2 commits into from
Dec 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions src/aiida_pseudo/groups/family/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@

__all__ = ('PseudoPotentialFamily',)

StructureData = DataFactory('core.structure')
LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name

try:
StructureData = DataFactory('atomistic.structure')
except exceptions.MissingEntryPointError:
structures_classes = (LegacyStructureData,)
else:
structures_classes = (LegacyStructureData, StructureData)


class PseudoPotentialFamily(Group):
Expand Down Expand Up @@ -308,12 +315,12 @@ def get_pseudos(
self,
*,
elements: Optional[Union[List[str], Tuple[str]]] = None,
structure: StructureData = None,
) -> Mapping[str, StructureData]:
structure: Union[structures_classes] = None,
) -> Mapping[str, Union[structures_classes]]:
"""Return the mapping of kind names on pseudo potential data nodes for the given list of elements or structure.

:param elements: list of element symbols.
:param structure: the ``StructureData`` node.
:param structure: the ``StructureData`` or ``LegacyStructureData`` node.
:return: dictionary mapping the kind names of a structure on the corresponding pseudo potential data nodes.
:raises ValueError: if the family does not contain a pseudo for any of the elements of the given structure.
"""
Expand All @@ -323,11 +330,11 @@ def get_pseudos(
if elements is None and structure is None:
raise ValueError('have to specify one of the keyword arguments `elements` and `structure`.')

if elements is not None and not isinstance(elements, (list, tuple)) and not isinstance(elements, StructureData):
if elements is not None and not isinstance(elements, (list, tuple)):
raise ValueError('elements should be a list or tuple of symbols.')

if structure is not None and not isinstance(structure, StructureData):
raise ValueError('structure should be a `StructureData` instance.')
if structure is not None and not (isinstance(structure, structures_classes)):
raise ValueError(f'structure is of type {type(structure)} but should be of: {structures_classes}')

if structure is not None:
return {kind.name: self.get_pseudo(kind.symbol) for kind in structure.kinds}
Expand Down
12 changes: 10 additions & 2 deletions src/aiida_pseudo/groups/mixins/cutoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@
import warnings
from typing import Optional

from aiida.common.exceptions import MissingEntryPointError
from aiida.common.lang import type_check
from aiida.plugins import DataFactory

from aiida_pseudo.common.units import U

StructureData = DataFactory('core.structure')
LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name
mikibonacci marked this conversation as resolved.
Show resolved Hide resolved

try:
StructureData = DataFactory('atomistic.structure')
except MissingEntryPointError:
structures_classes = (LegacyStructureData,)
else:
structures_classes = (LegacyStructureData, StructureData)

__all__ = ('RecommendedCutoffMixin',)

Expand Down Expand Up @@ -278,7 +286,7 @@ def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=N
raise ValueError('at least one and only one of `elements` or `structure` should be defined')

type_check(elements, (tuple, str), allow_none=True)
type_check(structure, StructureData, allow_none=True)
type_check(structure, (structures_classes), allow_none=True)

if unit is not None:
self.validate_cutoffs_unit(unit)
Expand Down
52 changes: 51 additions & 1 deletion tests/groups/family/test_pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,19 @@
import pytest
from aiida.common import exceptions
from aiida.orm import QueryBuilder
from aiida.plugins import DataFactory
from aiida_pseudo.data.pseudo import PseudoPotentialData
from aiida_pseudo.groups.family.pseudo import PseudoPotentialFamily

try:
DataFactory('atomistic.structure')
except exceptions.MissingEntryPointError:
has_atomistic = False
else:
has_atomistic = True

skip_atomistic = pytest.mark.skipif(not has_atomistic, reason='Unable to import aiida-atomistic')


def test_type_string():
"""Verify the `_type_string` class attribute is correctly set to the corresponding entry point name."""
Expand Down Expand Up @@ -408,7 +418,9 @@ def test_get_pseudos_raise(get_pseudo_family, generate_structure):
with pytest.raises(ValueError, match='elements should be a list or tuple of symbols.'):
family.get_pseudos(elements={'He', 'Ar'})

with pytest.raises(ValueError, match='structure should be a `StructureData` instance.'):
with pytest.raises(
ValueError, match=r"but should be of: \(<class 'aiida.orm.nodes.data.structure.StructureData'>,"
):
family.get_pseudos(structure={'He', 'Ar'})

with pytest.raises(ValueError, match=r'family `.*` does not contain pseudo for element `.*`'):
Expand Down Expand Up @@ -454,3 +466,41 @@ def test_get_pseudos_structure_kinds(get_pseudo_family, generate_structure):
assert isinstance(pseudos, dict)
for element in elements:
assert isinstance(pseudos[element], PseudoPotentialData)


@skip_atomistic
@pytest.mark.usefixtures('aiida_profile_clean')
def test_get_pseudos_atomsitic_structure(get_pseudo_family, generate_structure):
"""
Test the `PseudoPotentialFamily.get_pseudos` method when passing
an aiida-atomistic ``StructureData`` instance.
"""

elements = ('Ar', 'He', 'Ne')
orm_structure = generate_structure(elements)
structure = orm_structure.to_atomistic()
family = get_pseudo_family(elements=elements)

pseudos = family.get_pseudos(structure=structure)
assert isinstance(pseudos, dict)
for element in elements:
assert isinstance(pseudos[element], PseudoPotentialData)


@skip_atomistic
@pytest.mark.usefixtures('aiida_profile_clean')
def test_get_pseudos_atomistic_structure_kinds(get_pseudo_family, generate_structure):
"""
Test the `PseudoPotentialFamily.get_pseudos` for
an aiida-atomistic ``StructureData`` with kind names including digits.
"""

elements = ('Ar1', 'Ar2')
orm_structure = generate_structure(elements)
structure = orm_structure.to_atomistic()
family = get_pseudo_family(elements=elements)

pseudos = family.get_pseudos(structure=structure)
assert isinstance(pseudos, dict)
for element in elements:
assert isinstance(pseudos[element], PseudoPotentialData)