diff --git a/src/aiida_pseudo/groups/family/pseudo.py b/src/aiida_pseudo/groups/family/pseudo.py index 4463668..d2e4a6a 100644 --- a/src/aiida_pseudo/groups/family/pseudo.py +++ b/src/aiida_pseudo/groups/family/pseudo.py @@ -11,7 +11,14 @@ __all__ = ('PseudoPotentialFamily',) -StructureData = DataFactory('core.structure') +LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name + +try: + StructureData = DataFactory('atomistic.structure') +except exceptions.MissingEntryPointError: + structures_classes = (LegacyStructureData,) +else: + structures_classes = (LegacyStructureData, StructureData) class PseudoPotentialFamily(Group): @@ -308,12 +315,12 @@ def get_pseudos( self, *, elements: Optional[Union[List[str], Tuple[str]]] = None, - structure: StructureData = None, - ) -> Mapping[str, StructureData]: + structure: Union[structures_classes] = None, + ) -> Mapping[str, Union[structures_classes]]: """Return the mapping of kind names on pseudo potential data nodes for the given list of elements or structure. :param elements: list of element symbols. - :param structure: the ``StructureData`` node. + :param structure: the ``StructureData`` or ``LegacyStructureData`` node. :return: dictionary mapping the kind names of a structure on the corresponding pseudo potential data nodes. :raises ValueError: if the family does not contain a pseudo for any of the elements of the given structure. """ @@ -323,11 +330,11 @@ def get_pseudos( if elements is None and structure is None: raise ValueError('have to specify one of the keyword arguments `elements` and `structure`.') - if elements is not None and not isinstance(elements, (list, tuple)) and not isinstance(elements, StructureData): + if elements is not None and not isinstance(elements, (list, tuple)): raise ValueError('elements should be a list or tuple of symbols.') - if structure is not None and not isinstance(structure, StructureData): - raise ValueError('structure should be a `StructureData` instance.') + if structure is not None and not (isinstance(structure, structures_classes)): + raise ValueError(f'structure is of type {type(structure)} but should be of: {structures_classes}') if structure is not None: return {kind.name: self.get_pseudo(kind.symbol) for kind in structure.kinds} diff --git a/src/aiida_pseudo/groups/mixins/cutoffs.py b/src/aiida_pseudo/groups/mixins/cutoffs.py index e19ba92..3f37f57 100644 --- a/src/aiida_pseudo/groups/mixins/cutoffs.py +++ b/src/aiida_pseudo/groups/mixins/cutoffs.py @@ -2,12 +2,20 @@ import warnings from typing import Optional +from aiida.common.exceptions import MissingEntryPointError from aiida.common.lang import type_check from aiida.plugins import DataFactory from aiida_pseudo.common.units import U -StructureData = DataFactory('core.structure') +LegacyStructureData = DataFactory('core.structure') # pylint: disable=invalid-name + +try: + StructureData = DataFactory('atomistic.structure') +except MissingEntryPointError: + structures_classes = (LegacyStructureData,) +else: + structures_classes = (LegacyStructureData, StructureData) __all__ = ('RecommendedCutoffMixin',) @@ -278,7 +286,7 @@ def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=N raise ValueError('at least one and only one of `elements` or `structure` should be defined') type_check(elements, (tuple, str), allow_none=True) - type_check(structure, StructureData, allow_none=True) + type_check(structure, (structures_classes), allow_none=True) if unit is not None: self.validate_cutoffs_unit(unit) diff --git a/tests/groups/family/test_pseudo.py b/tests/groups/family/test_pseudo.py index 9ee5276..d736944 100644 --- a/tests/groups/family/test_pseudo.py +++ b/tests/groups/family/test_pseudo.py @@ -4,9 +4,19 @@ import pytest from aiida.common import exceptions from aiida.orm import QueryBuilder +from aiida.plugins import DataFactory from aiida_pseudo.data.pseudo import PseudoPotentialData from aiida_pseudo.groups.family.pseudo import PseudoPotentialFamily +try: + DataFactory('atomistic.structure') +except exceptions.MissingEntryPointError: + has_atomistic = False +else: + has_atomistic = True + +skip_atomistic = pytest.mark.skipif(not has_atomistic, reason='Unable to import aiida-atomistic') + def test_type_string(): """Verify the `_type_string` class attribute is correctly set to the corresponding entry point name.""" @@ -408,7 +418,9 @@ def test_get_pseudos_raise(get_pseudo_family, generate_structure): with pytest.raises(ValueError, match='elements should be a list or tuple of symbols.'): family.get_pseudos(elements={'He', 'Ar'}) - with pytest.raises(ValueError, match='structure should be a `StructureData` instance.'): + with pytest.raises( + ValueError, match=r"but should be of: \(," + ): family.get_pseudos(structure={'He', 'Ar'}) with pytest.raises(ValueError, match=r'family `.*` does not contain pseudo for element `.*`'): @@ -454,3 +466,41 @@ def test_get_pseudos_structure_kinds(get_pseudo_family, generate_structure): assert isinstance(pseudos, dict) for element in elements: assert isinstance(pseudos[element], PseudoPotentialData) + + +@skip_atomistic +@pytest.mark.usefixtures('aiida_profile_clean') +def test_get_pseudos_atomsitic_structure(get_pseudo_family, generate_structure): + """ + Test the `PseudoPotentialFamily.get_pseudos` method when passing + an aiida-atomistic ``StructureData`` instance. + """ + + elements = ('Ar', 'He', 'Ne') + orm_structure = generate_structure(elements) + structure = orm_structure.to_atomistic() + family = get_pseudo_family(elements=elements) + + pseudos = family.get_pseudos(structure=structure) + assert isinstance(pseudos, dict) + for element in elements: + assert isinstance(pseudos[element], PseudoPotentialData) + + +@skip_atomistic +@pytest.mark.usefixtures('aiida_profile_clean') +def test_get_pseudos_atomistic_structure_kinds(get_pseudo_family, generate_structure): + """ + Test the `PseudoPotentialFamily.get_pseudos` for + an aiida-atomistic ``StructureData`` with kind names including digits. + """ + + elements = ('Ar1', 'Ar2') + orm_structure = generate_structure(elements) + structure = orm_structure.to_atomistic() + family = get_pseudo_family(elements=elements) + + pseudos = family.get_pseudos(structure=structure) + assert isinstance(pseudos, dict) + for element in elements: + assert isinstance(pseudos[element], PseudoPotentialData)