diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index 750df1ba..e5718304 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -702,6 +702,23 @@ def decorator(func): return decorator + @staticmethod + @nonfunctional_usage + def awaitable_builder(**kwargs: Any) -> Callable: + def decorator(func): + # Then, apply task decorator + task_decorated = build_task_from_callable( + func, + inputs=kwargs.get("inputs", []), + outputs=kwargs.get("outputs", []), + ) + task_decorated.node_type = "awaitable_builder" + func.identifier = "awaitable_builder" + func.task = func.node = task_decorated + return func + + return decorator + # Making decorator_task accessible as 'task' task = decorator_task diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index 82332a78..b9e78e11 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -1,6 +1,19 @@ +from __future__ import annotations from aiida_workgraph.orm.serializer import serialize_to_aiida_nodes from aiida import orm from aiida.common.extendeddicts import AttributeDict +from aiida.engine.utils import instantiate_process, prepare_inputs +from aiida.manage import manager +from aiida.engine import run_get_node +from aiida.common import InvalidOperation +from aiida.common.log import AIIDA_LOGGER +from aiida.engine.processes import Process, ProcessBuilder +from aiida.orm import ProcessNode +import typing as t +import time + +TYPE_SUBMIT_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder] +LOGGER = AIIDA_LOGGER.getChild("engine.launch") def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple: @@ -139,3 +152,66 @@ def prepare_for_shell_task(task: dict, kwargs: dict) -> dict: "metadata": metadata or {}, } return inputs + + +# modified from aiida.engine.submit +# do not check the scope of the process +def submit( + process: TYPE_SUBMIT_PROCESS, + inputs: dict[str, t.Any] | None = None, + *, + wait: bool = False, + wait_interval: int = 5, + **kwargs: t.Any, +) -> ProcessNode: + """Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter. + + .. warning: this should not be used within another process. Instead, there one should use the ``submit`` method of + the wrapping process itself, i.e. use ``self.submit``. + + .. warning: submission of processes requires ``store_provenance=True``. + + :param process: the process class, instance or builder to submit + :param inputs: the input dictionary to be passed to the process + :param wait: when set to ``True``, the submission will be blocking and wait for the process to complete at which + point the function returns the calculation node. + :param wait_interval: the number of seconds to wait between checking the state of the process when ``wait=True``. + :param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument. + :return: the calculation node of the process + """ + inputs = prepare_inputs(inputs, **kwargs) + + runner = manager.get_manager().get_runner() + assert runner.persister is not None, "runner does not have a persister" + assert runner.controller is not None, "runner does not have a controller" + + process_inited = instantiate_process(runner, process, **inputs) + + # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this + # instead of raising, because in this way the user does not have to change the launcher when testing. The same goes + # for if `remote_folder` is present in the inputs, which means we are importing an already completed calculation. + if process_inited.metadata.get("dry_run", False) or "remote_folder" in inputs: + _, node = run_get_node(process_inited) + return node + + if not process_inited.metadata.store_provenance: + raise InvalidOperation("cannot submit a process with `store_provenance=False`") + + runner.persister.save_checkpoint(process_inited) + process_inited.close() + + # Do not wait for the future's result, because in the case of a single worker this would cock-block itself + runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True) + node = process_inited.node + + if not wait: + return node + + while not node.is_terminated: + LOGGER.report( + f"Process<{node.pk}> has not yet terminated, current state is `{node.process_state}`. " + f"Waiting for {wait_interval} seconds." + ) + time.sleep(wait_interval) + + return node diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index dfac6d92..d192e3e5 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -42,7 +42,8 @@ __all__ = "WorkGraph" -MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}. Cannot launch the job: {}." +MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}.\ +Waiting for other jobs to finish before launching the {}." @auto_persist("_awaitables") @@ -304,6 +305,9 @@ def _do_step(self) -> t.Any: else: finished, result = self.is_workgraph_finished() + if self._awaitables: + return Wait(self._do_step, "Waiting before next step") + # If the workgraph is finished or the result is an ExitCode, we exit by returning if finished: if isinstance(result, ExitCode): @@ -311,9 +315,6 @@ def _do_step(self) -> t.Any: else: return self.finalize() - if self._awaitables: - return Wait(self._do_step, "Waiting before next step") - return Continue(self._do_step) def _store_nodes(self, data: t.Any) -> None: @@ -442,7 +443,10 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: # node finished, update the task state and result # udpate the task state - self.update_task_state(awaitable.key) + if awaitable.key in self.ctx._tasks: + self.update_task_state(awaitable.key) + else: + self.report(f"Awaitable {awaitable.key} finished.") # try to resume the workgraph, if the workgraph is already resumed # by other awaitable, this will not work try: @@ -956,9 +960,10 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None "WORKGRAPH", "PYTHONJOB", "SHELLJOB", + "AWAITABLE_BUILDER", ]: if len(self._awaitables) >= self.ctx._max_number_awaitables: - print( + self.report( MAX_NUMBER_AWAITABLES_MSG.format( self.ctx._max_number_awaitables, name ) @@ -1151,6 +1156,32 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None self.set_task_state_info(name, "state", "FINISHED") self.update_parent_task_state(name) self.continue_workgraph() + elif task["metadata"]["node_type"].upper() in ["AWAITABLE_BUILDER"]: + # create the awaitable + for key in self.ctx._tasks[name]["metadata"]["args"]: + kwargs.pop(key, None) + results = self.run_executor( + executor, args, kwargs, var_args, var_kwargs + ) + if not isinstance(results, dict): + + self.logger.error( + "The results of the awaitable builder must be a dict." + ) + for key, value in results.items(): + if not isinstance(value, ProcessNode): + self.logger.error( + f"The value of key {key} is not an instance of ProcessNode." + ) + self.set_task_state_info(name, "state", "Failed") + self.set_task_state_info(name, "state", "Failed") + self.report(f"Task: {name} failed.") + else: + self.set_task_state_info(name, "state", "FINISHED") + self.to_context(**results) + self.report(f"Task: {name} finished.") + self.update_parent_task_state(name) + self.continue_workgraph() elif task["metadata"]["node_type"].upper() in ["AWAITABLE"]: for key in self.ctx._tasks[name]["metadata"]["args"]: kwargs.pop(key, None)