Skip to content

Commit

Permalink
Use custom WorkflowFactory to provide plugin install instructions
Browse files Browse the repository at this point in the history
The `WorkflowFactory` from `aiida-core` is replaced with a custom
version in the `aiida_common_workflows.plugins.factories` module. This
function will call the factory from `aiida-core` but catch the
`MissingEntryPointError` exception. In this case, if the entry point
corresponds to a plugin implementation of one of the common workflows
the exception is reraised but with a useful message that provides the
user with the install command to install the necessary plugin package.

While this should catch all cases of users trying to load a workflow for
a plugin that is not installed through its entry point, it won't catch
import errors that are raised when a module is imported directly from
that plugin package. Therefore, these imports should not be placed at
the top of modules, but placed inside functions/methods of the
implementation as much as possible.
  • Loading branch information
sphuber committed Mar 4, 2024
1 parent 0d5f228 commit 340dbee
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 19 deletions.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ If more flexibility is required, it is advised to write a custom launch script,
.. code:: python
from aiida.engine import submit
from aiida.plugins import WorkflowFactory
from aiida_common_workflows.plugins import WorkflowFactory
RelaxWorkChain = WorkflowFactory('common_workflows.relax.quantum_espresso') # Load the relax workflow implementation of choice.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/workflows/base/relax/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ A typical script for the submission of common relax workflow could look somethin
.. code:: python
from aiida.engine import submit
from aiida.plugins import WorkflowFactory
from aiida_common_workflows.plugins import WorkflowFactory
RelaxWorkChain = WorkflowFactory('common_workflows.relax.<implementation>') # Load the relax workflow implementation of choice.
Expand Down
2 changes: 1 addition & 1 deletion docs/source/workflows/composite/dc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ A typical script for the submission of common DC workflow could look something l
from aiida.orm import List, Dict
from aiida.engine import submit
from aiida.plugins import WorkflowFactory
from aiida_common_workflows.plugins import WorkflowFactory
cls = WorkflowFactory('common_workflows.dissociation_curve')
Expand Down
2 changes: 1 addition & 1 deletion docs/source/workflows/composite/eos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ A typical script for the submission of common EoS workflow could look something
from aiida.orm import List, Dict
from aiida.engine import submit
from aiida.plugins import WorkflowFactory
from aiida_common_workflows.plugins import WorkflowFactory
cls = WorkflowFactory('common_workflows.eos')
Expand Down
8 changes: 7 additions & 1 deletion src/aiida_common_workflows/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
"""Module with utilities for working with the plugins provided by this plugin package."""
from .entry_point import get_entry_point_name_from_class, get_workflow_entry_point_names, load_workflow_entry_point
from .factories import WorkflowFactory

__all__ = ('get_workflow_entry_point_names', 'get_entry_point_name_from_class', 'load_workflow_entry_point')
__all__ = (
'WorkflowFactory',
'get_workflow_entry_point_names',
'get_entry_point_name_from_class',
'load_workflow_entry_point',
)
6 changes: 4 additions & 2 deletions src/aiida_common_workflows/plugins/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from aiida.plugins import entry_point

from .factories import WorkflowFactory

PACKAGE_PREFIX = 'common_workflows'

__all__ = ('get_workflow_entry_point_names', 'get_entry_point_name_from_class', 'load_workflow_entry_point')
Expand Down Expand Up @@ -38,5 +40,5 @@ def load_workflow_entry_point(workflow: str, plugin_name: str):
:param plugin_name: name of the plugin implementation.
:return: the workchain class of the plugin implementation of the common workflow.
"""
prefix = f'{PACKAGE_PREFIX}.{workflow}.{plugin_name}'
return entry_point.load_entry_point('aiida.workflows', prefix)
entry_point_name = f'{PACKAGE_PREFIX}.{workflow}.{plugin_name}'
return WorkflowFactory(entry_point_name)
47 changes: 47 additions & 0 deletions src/aiida_common_workflows/plugins/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Factories to load entry points."""
import typing as t

from aiida import plugins
from aiida.common import exceptions

if t.TYPE_CHECKING:
from aiida.engine import WorkChain
from importlib_metadata import EntryPoint

__all__ = ('WorkflowFactory',)


@t.overload
def WorkflowFactory(entry_point_name: str, load: t.Literal[True] = True) -> t.Union[t.Type['WorkChain'], t.Callable]:
...


@t.overload
def WorkflowFactory(entry_point_name: str, load: t.Literal[False]) -> 'EntryPoint':
...


def WorkflowFactory(entry_point_name: str, load: bool = True) -> t.Union['EntryPoint', t.Type['WorkChain'], t.Callable]: # noqa: N802
"""Return the `WorkChain` sub class registered under the given entry point.
:param entry_point_name: the entry point name.
:param load: if True, load the matched entry point and return the loaded resource instead of the entry point itself.
:return: sub class of :py:class:`~aiida.engine.processes.workchains.workchain.WorkChain` or a `workfunction`
:raises aiida.common.MissingEntryPointError: entry point was not registered
:raises aiida.common.MultipleEntryPointError: entry point could not be uniquely resolved
:raises aiida.common.LoadingEntryPointError: entry point could not be loaded
:raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
"""
common_workflow_prefixes = ('common_workflows.relax.', 'common_workflows.bands.')
try:
return plugins.WorkflowFactory(entry_point_name, load)
except exceptions.MissingEntryPointError as exception:
for prefix in common_workflow_prefixes:
if entry_point_name.startswith(prefix):
plugin_name = entry_point_name.removeprefix(prefix)
raise exceptions.MissingEntryPointError(
f'Could not load the entry point `{entry_point_name}`, probably because the plugin package is not '
f'installed. Please install it with `pip install aiida-common-workflows[{plugin_name}]`.'
) from exception
else: # noqa: PLW0120
raise
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from aiida import orm
from aiida.common import exceptions
from aiida.engine import calcfunction
from aiida_abinit.workflows.base import AbinitBaseWorkChain
from aiida.plugins import WorkflowFactory

from ..workchain import CommonRelaxWorkChain
from .generator import AbinitCommonRelaxInputGenerator
Expand Down Expand Up @@ -44,7 +44,7 @@ def get_total_magnetization(parameters):
class AbinitCommonRelaxWorkChain(CommonRelaxWorkChain):
"""Implementation of `aiida_common_workflows.common.relax.workchain.CommonRelaxWorkChain` for Abinit."""

_process_class = AbinitBaseWorkChain
_process_class = WorkflowFactory('abinit.base')
_generator_class = AbinitCommonRelaxInputGenerator

def convert_outputs(self):
Expand Down
17 changes: 9 additions & 8 deletions src/aiida_common_workflows/workflows/relax/castep/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
import yaml
from aiida import engine, orm, plugins
from aiida.common import exceptions
from aiida_castep.data import get_pseudos_from_structure
from aiida_castep.data.otfg import OTFGGroup

from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
from aiida_common_workflows.generators import ChoiceType, CodeType

from ..generator import CommonRelaxInputGenerator

if t.TYPE_CHECKING:
from aiida_castep.data.otfg import OTFGGroup

KNOWN_BUILTIN_FAMILIES = ('C19', 'NCP19', 'QC5', 'C17', 'C9')

__all__ = ('CastepCommonRelaxInputGenerator',)
Expand Down Expand Up @@ -247,8 +248,8 @@ def generate_inputs(
:param override: a dictionary to override specific inputs
:return: input dictionary
"""

from aiida.common.lang import type_check
from aiida_castep.data.otfg import OTFGGroup

family_name = protocol['relax']['base']['pseudos_family']
if isinstance(family_name, orm.Str):
Expand Down Expand Up @@ -285,7 +286,7 @@ def generate_inputs_relax(
protocol: t.Dict,
code: orm.Code,
structure: orm.StructureData,
otfg_family: OTFGGroup,
otfg_family: 'OTFGGroup',
override: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Dict[str, t.Any]:
"""Generate the inputs for the `CastepCommonRelaxWorkChain` for a given code, structure and pseudo potential family.
Expand Down Expand Up @@ -321,7 +322,7 @@ def generate_inputs_base(
protocol: t.Dict,
code: orm.Code,
structure: orm.StructureData,
otfg_family: OTFGGroup,
otfg_family: 'OTFGGroup',
override: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Dict[str, t.Any]:
"""Generate the inputs for the `CastepBaseWorkChain` for a given code, structure and pseudo potential family.
Expand Down Expand Up @@ -359,7 +360,7 @@ def generate_inputs_calculation(
protocol: t.Dict,
code: orm.Code,
structure: orm.StructureData,
otfg_family: OTFGGroup,
otfg_family: 'OTFGGroup',
override: t.Optional[t.Dict[str, t.Any]] = None,
) -> t.Dict[str, t.Any]:
"""Generate the inputs for the `CastepCalculation` for a given code, structure and pseudo potential family.
Expand All @@ -372,6 +373,7 @@ def generate_inputs_calculation(
:return: the fully defined input dictionary.
"""
from aiida_castep.calculations.helper import CastepHelper
from aiida_castep.data import get_pseudos_from_structure

override = {} if not override else override.get('calc', {})
# This merge perserves the merged `parameters` in the override
Expand Down Expand Up @@ -415,9 +417,8 @@ def ensure_otfg_family(family_name, force_update=False):
NOTE: CASTEP also supports UPF families, but it is not enabled here, since no UPS based protocol
has been implemented.
"""

from aiida.common import NotExistent
from aiida_castep.data.otfg import upload_otfg_family
from aiida_castep.data.otfg import OTFGGroup, upload_otfg_family

# Ensure family name is a str
if isinstance(family_name, orm.Str):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import yaml
from aiida import engine, orm, plugins
from aiida_quantumespresso.workflows.protocols.utils import recursive_merge

from aiida_common_workflows.common import ElectronicType, RelaxType, SpinType
from aiida_common_workflows.generators import ChoiceType, CodeType
Expand Down Expand Up @@ -108,8 +107,8 @@ def _construct_builder(self, **kwargs) -> engine.ProcessBuilder: # noqa: PLR091
The keyword arguments will have been validated against the input generator specification.
"""

from aiida_quantumespresso.common import types
from aiida_quantumespresso.workflows.protocols.utils import recursive_merge
from qe_tools import CONSTANTS

structure = kwargs['structure']
Expand Down
28 changes: 28 additions & 0 deletions tests/test_minimal_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
installed. This guarantees that most of the code can be imported without any plugin packages being installed.
"""
import pytest
from aiida.common import exceptions
from aiida_common_workflows.plugins import WorkflowFactory, get_workflow_entry_point_names


@pytest.mark.minimal_install
Expand All @@ -18,3 +20,29 @@ def test_imports():
import aiida_common_workflows.workflows
import aiida_common_workflows.workflows.dissociation
import aiida_common_workflows.workflows.eos # noqa: F401


@pytest.mark.minimal_install
@pytest.mark.parametrize('entry_point_name', get_workflow_entry_point_names('relax'))
def test_workflow_factory_relax(entry_point_name):
"""Test that trying to load common relax workflow implementations will raise if not installed.
The exception message should provide the pip command to install the require plugin package.
"""
plugin_name = entry_point_name.removeprefix('common_workflows.relax.')
match = rf'.*plugin package is not installed.*`pip install aiida-common-workflows\[{plugin_name}\]`.*'
with pytest.raises(exceptions.MissingEntryPointError, match=match):
WorkflowFactory(entry_point_name)


@pytest.mark.minimal_install
@pytest.mark.parametrize('entry_point_name', get_workflow_entry_point_names('bands'))
def test_workflow_factory_bands(entry_point_name):
"""Test that trying to load common bands workflow implementations will raise if not installed.
The exception message should provide the pip command to install the require plugin package.
"""
plugin_name = entry_point_name.removeprefix('common_workflows.bands.')
match = rf'.*plugin package is not installed.*`pip install aiida-common-workflows\[{plugin_name}\]`.*'
with pytest.raises(exceptions.MissingEntryPointError, match=match):
WorkflowFactory(entry_point_name)

0 comments on commit 340dbee

Please sign in to comment.