Add awaitable_builder decorator #239

Draft
wants to merge 14 commits into base: main
17 changes: 17 additions & 0 deletions aiida_workgraph/decorator.py
@@ -702,6 +702,23 @@ def decorator(func):

return decorator

@staticmethod
@nonfunctional_usage
def awaitable_builder(**kwargs: Any) -> Callable:
    def decorator(func):
        # build a task from the decorated callable
        task_decorated = build_task_from_callable(
            func,
            inputs=kwargs.get("inputs", []),
            outputs=kwargs.get("outputs", []),
        )
        task_decorated.node_type = "awaitable_builder"
        func.identifier = "awaitable_builder"
        func.task = func.node = task_decorated
        return func

    return decorator

# Making decorator_task accessible as 'task'
task = decorator_task
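A usage sketch for the new decorator (not part of the diff; the import path and the ``outputs`` format follow the existing task decorators, and the builder body is hypothetical). It shows what the decorator actually does: it builds a task, tags it with the ``awaitable_builder`` node type, and attaches it to the function:

from aiida_workgraph import task

@task.awaitable_builder(outputs=[{"name": "result"}])
def my_builder(**kwargs):
    ...

# the decorator attaches the built task to the function itself
assert my_builder.identifier == "awaitable_builder"
assert my_builder.task.node_type == "awaitable_builder"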

76 changes: 76 additions & 0 deletions aiida_workgraph/engine/utils.py
@@ -1,6 +1,19 @@
from __future__ import annotations
from aiida_workgraph.orm.serializer import serialize_to_aiida_nodes
from aiida import orm
from aiida.common.extendeddicts import AttributeDict
from aiida.engine.utils import instantiate_process, prepare_inputs
from aiida.manage import manager
from aiida.engine import run_get_node
from aiida.common import InvalidOperation
from aiida.common.log import AIIDA_LOGGER
from aiida.engine.processes import Process, ProcessBuilder
from aiida.orm import ProcessNode
import typing as t
import time

TYPE_SUBMIT_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder]
LOGGER = AIIDA_LOGGER.getChild("engine.launch")


def prepare_for_workgraph_task(task: dict, kwargs: dict) -> tuple:
@@ -139,3 +152,66 @@ def prepare_for_shell_task(task: dict, kwargs: dict) -> dict:
        "metadata": metadata or {},
    }
    return inputs


# modified from aiida.engine.submit:
# unlike the original, this version does not check whether it is called from within the scope of another process
def submit(
    process: TYPE_SUBMIT_PROCESS,
    inputs: dict[str, t.Any] | None = None,
    *,
    wait: bool = False,
    wait_interval: int = 5,
    **kwargs: t.Any,
) -> ProcessNode:
    """Submit the process with the supplied inputs to the daemon, immediately returning control to the interpreter.

    .. warning:: this should not be used within another process. Instead, one should use the ``submit`` method of
        the wrapping process itself, i.e. use ``self.submit``.

    .. warning:: submission of processes requires ``store_provenance=True``.

    :param process: the process class, instance or builder to submit
    :param inputs: the input dictionary to be passed to the process
    :param wait: when set to ``True``, the submission will be blocking and wait for the process to complete, at which
        point the function returns the calculation node.
    :param wait_interval: the number of seconds to wait between checks on the state of the process when ``wait=True``.
    :param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument.
    :return: the calculation node of the process
    """
    inputs = prepare_inputs(inputs, **kwargs)

    runner = manager.get_manager().get_runner()
    assert runner.persister is not None, "runner does not have a persister"
    assert runner.controller is not None, "runner does not have a controller"

    process_inited = instantiate_process(runner, process, **inputs)

    # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose this
    # over raising so that the user does not have to change the launcher when testing. The same goes for when
    # `remote_folder` is present in the inputs, which means we are importing an already completed calculation.
    if process_inited.metadata.get("dry_run", False) or "remote_folder" in inputs:
        _, node = run_get_node(process_inited)
        return node

    if not process_inited.metadata.store_provenance:
        raise InvalidOperation("cannot submit a process with `store_provenance=False`")

    runner.persister.save_checkpoint(process_inited)
    process_inited.close()

    # Do not wait for the future's result, because in the case of a single worker this would deadlock
    runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True)
    node = process_inited.node

    if not wait:
        return node

    while not node.is_terminated:
        LOGGER.report(
            f"Process<{node.pk}> has not yet terminated, current state is `{node.process_state}`. "
            f"Waiting for {wait_interval} seconds."
        )
        time.sleep(wait_interval)

    return node
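A usage sketch for this helper (the builder variable is hypothetical; any process class, instance, or ProcessBuilder accepted by ``instantiate_process`` would work):

from aiida_workgraph.engine.utils import submit

# submit to the daemon and block until the process terminates,
# polling its state every 10 seconds
node = submit(my_calcjob_builder, wait=True, wait_interval=10)
print(node.process_state)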
43 changes: 37 additions & 6 deletions aiida_workgraph/engine/workgraph.py
@@ -42,7 +42,8 @@
__all__ = "WorkGraph"


-MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}. Cannot launch the job: {}."
+MAX_NUMBER_AWAITABLES_MSG = "The maximum number of subprocesses has been reached: {}. \
+Waiting for other jobs to finish before launching the {}."


@auto_persist("_awaitables")
@@ -304,16 +305,16 @@ def _do_step(self) -> t.Any:
        else:
            finished, result = self.is_workgraph_finished()

+       if self._awaitables:
+           return Wait(self._do_step, "Waiting before next step")
+
        # If the workgraph is finished or the result is an ExitCode, we exit by returning
        if finished:
            if isinstance(result, ExitCode):
                return result
            else:
                return self.finalize()

-       if self._awaitables:
-           return Wait(self._do_step, "Waiting before next step")
-
        return Continue(self._do_step)

def _store_nodes(self, data: t.Any) -> None:
@@ -442,7 +443,10 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None:

        # node finished, update the task state and result
-       self.update_task_state(awaitable.key)
+       if awaitable.key in self.ctx._tasks:
+           self.update_task_state(awaitable.key)
+       else:
+           self.report(f"Awaitable {awaitable.key} finished.")
        # try to resume the workgraph; if it was already resumed
        # by another awaitable, this is a no-op
        try:
@@ -956,9 +960,10 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
                "WORKGRAPH",
                "PYTHONJOB",
                "SHELLJOB",
+               "AWAITABLE_BUILDER",
            ]:
                if len(self._awaitables) >= self.ctx._max_number_awaitables:
-                   print(
+                   self.report(
                        MAX_NUMBER_AWAITABLES_MSG.format(
                            self.ctx._max_number_awaitables, name
                        )
@@ -1151,6 +1156,32 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
                self.set_task_state_info(name, "state", "FINISHED")
                self.update_parent_task_state(name)
                self.continue_workgraph()
            elif task["metadata"]["node_type"].upper() in ["AWAITABLE_BUILDER"]:
                # run the executor to create the awaitables
                for key in self.ctx._tasks[name]["metadata"]["args"]:
                    kwargs.pop(key, None)
                results = self.run_executor(
                    executor, args, kwargs, var_args, var_kwargs
                )
                # the builder must return a dict of ProcessNode instances,
                # keyed by the context names under which they will be awaited
                if not isinstance(results, dict):
                    self.logger.error(
                        "The results of the awaitable builder must be a dict."
                    )
                    self.set_task_state_info(name, "state", "Failed")
                    self.report(f"Task: {name} failed.")
                else:
                    invalid_keys = [
                        key
                        for key, value in results.items()
                        if not isinstance(value, ProcessNode)
                    ]
                    if invalid_keys:
                        for key in invalid_keys:
                            self.logger.error(
                                f"The value of key {key} is not an instance of ProcessNode."
                            )
                        self.set_task_state_info(name, "state", "Failed")
                        self.report(f"Task: {name} failed.")
                    else:
                        self.set_task_state_info(name, "state", "FINISHED")
                        self.to_context(**results)
                        self.report(f"Task: {name} finished.")
                self.update_parent_task_state(name)
                self.continue_workgraph()
            elif task["metadata"]["node_type"].upper() in ["AWAITABLE"]:
                for key in self.ctx._tasks[name]["metadata"]["args"]:
                    kwargs.pop(key, None)
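For orientation, a minimal end-to-end sketch (not part of the diff; all names are hypothetical): the dict returned by the builder is handed to ``to_context``, so each ProcessNode is awaited under its key, and ``_on_awaitable_finished`` resumes the workgraph as each one terminates:

from aiida_workgraph import task
from aiida_workgraph.engine.utils import submit

@task.awaitable_builder()
def submit_all(builders):
    # one daemon submission per builder; the engine awaits
    # ctx.job_0, ctx.job_1, ... and resumes as they finish
    return {f"job_{i}": submit(b) for i, b in enumerate(builders)}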