forked from aiidateam/plumpy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
365 additions
and
340 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,310 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Module for process level communication functions and classes""" | ||
|
||
import asyncio | ||
import logging | ||
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast | ||
|
||
import kiwipy | ||
|
||
from . import loaders, persistence | ||
from .utils import PID_TYPE | ||
|
||
__all__ = [ | ||
'KILL_MSG', | ||
'PAUSE_MSG', | ||
'PLAY_MSG', | ||
'STATUS_MSG', | ||
'ProcessLauncher', | ||
'create_continue_body', | ||
'create_launch_body', | ||
] | ||
|
||
if TYPE_CHECKING: | ||
from .processes import Process | ||
|
||
INTENT_KEY = 'intent' | ||
MESSAGE_KEY = 'message' | ||
|
||
|
||
class Intent: | ||
"""Intent constants for a process message""" | ||
|
||
PLAY: str = 'play' | ||
PAUSE: str = 'pause' | ||
KILL: str = 'kill' | ||
STATUS: str = 'status' | ||
|
||
|
||
PAUSE_MSG = {INTENT_KEY: Intent.PAUSE} | ||
PLAY_MSG = {INTENT_KEY: Intent.PLAY} | ||
KILL_MSG = {INTENT_KEY: Intent.KILL} | ||
STATUS_MSG = {INTENT_KEY: Intent.STATUS} | ||
|
||
TASK_KEY = 'task' | ||
TASK_ARGS = 'args' | ||
PERSIST_KEY = 'persist' | ||
# Launch | ||
PROCESS_CLASS_KEY = 'process_class' | ||
ARGS_KEY = 'init_args' | ||
KWARGS_KEY = 'init_kwargs' | ||
NOWAIT_KEY = 'nowait' | ||
# Continue | ||
PID_KEY = 'pid' | ||
TAG_KEY = 'tag' | ||
# Task types | ||
LAUNCH_TASK = 'launch' | ||
CONTINUE_TASK = 'continue' | ||
CREATE_TASK = 'create' | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
def create_launch_body( | ||
process_class: str, | ||
init_args: Optional[Sequence[Any]] = None, | ||
init_kwargs: Optional[Dict[str, Any]] = None, | ||
persist: bool = False, | ||
loader: Optional[loaders.ObjectLoader] = None, | ||
nowait: bool = True, | ||
) -> Dict[str, Any]: | ||
""" | ||
Create a message body for the launch action | ||
:param process_class: the class of the process to launch | ||
:param init_args: any initialisation positional arguments | ||
:param init_kwargs: any initialisation keyword arguments | ||
:param persist: persist this process if True, otherwise don't | ||
:param loader: the loader to use to load the persisted process | ||
:param nowait: wait for the process to finish before completing the task, otherwise just return the PID | ||
:return: a dictionary with the body of the message to launch the process | ||
:rtype: dict | ||
""" | ||
if loader is None: | ||
loader = loaders.get_object_loader() | ||
|
||
msg_body = { | ||
TASK_KEY: LAUNCH_TASK, | ||
TASK_ARGS: { | ||
PROCESS_CLASS_KEY: loader.identify_object(process_class), | ||
PERSIST_KEY: persist, | ||
NOWAIT_KEY: nowait, | ||
ARGS_KEY: init_args, | ||
KWARGS_KEY: init_kwargs, | ||
}, | ||
} | ||
return msg_body | ||
|
||
|
||
def create_continue_body(pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False) -> Dict[str, Any]: | ||
""" | ||
Create a message body to continue an existing process | ||
:param pid: the pid of the existing process | ||
:param tag: the optional persistence tag | ||
:param nowait: wait for the process to finish before completing the task, otherwise just return the PID | ||
:return: a dictionary with the body of the message to continue the process | ||
""" | ||
msg_body = {TASK_KEY: CONTINUE_TASK, TASK_ARGS: {PID_KEY: pid, NOWAIT_KEY: nowait, TAG_KEY: tag}} | ||
return msg_body | ||
|
||
|
||
def create_create_body( | ||
process_class: str, | ||
init_args: Optional[Sequence[Any]] = None, | ||
init_kwargs: Optional[Dict[str, Any]] = None, | ||
persist: bool = False, | ||
loader: Optional[loaders.ObjectLoader] = None, | ||
) -> Dict[str, Any]: | ||
""" | ||
Create a message body to create a new process | ||
:param process_class: the class of the process to launch | ||
:param init_args: any initialisation positional arguments | ||
:param init_kwargs: any initialisation keyword arguments | ||
:param persist: persist this process if True, otherwise don't | ||
:param loader: the loader to use to load the persisted process | ||
:return: a dictionary with the body of the message to launch the process | ||
""" | ||
if loader is None: | ||
loader = loaders.get_object_loader() | ||
|
||
msg_body = { | ||
TASK_KEY: CREATE_TASK, | ||
TASK_ARGS: { | ||
PROCESS_CLASS_KEY: loader.identify_object(process_class), | ||
PERSIST_KEY: persist, | ||
ARGS_KEY: init_args, | ||
KWARGS_KEY: init_kwargs, | ||
}, | ||
} | ||
return msg_body | ||
|
||
|
||
class ProcessLauncher: | ||
""" | ||
Takes incoming task messages and uses them to launch processes. | ||
Expected format of task: | ||
For launch:: | ||
{ | ||
'task': <LAUNCH_TASK> | ||
'process_class': <Process class to launch> | ||
'args': <tuple of positional args for process constructor> | ||
'kwargs': <dict of keyword args for process constructor>. | ||
'nowait': True or False | ||
} | ||
For continue:: | ||
{ | ||
'task': <CONTINUE_TASK> | ||
'pid': <Process ID> | ||
'nowait': True or False | ||
} | ||
""" | ||
|
||
def __init__( | ||
self, | ||
loop: Optional[asyncio.AbstractEventLoop] = None, | ||
persister: Optional[persistence.Persister] = None, | ||
load_context: Optional[persistence.LoadSaveContext] = None, | ||
loader: Optional[loaders.ObjectLoader] = None, | ||
) -> None: | ||
self._loop = loop | ||
self._persister = persister | ||
self._load_context = load_context if load_context is not None else persistence.LoadSaveContext() | ||
|
||
if loader is not None: | ||
self._loader = loader | ||
self._load_context = self._load_context.copyextend(loader=loader) | ||
else: | ||
self._loader = loaders.get_object_loader() | ||
|
||
async def __call__(self, communicator: kiwipy.Communicator, task: Dict[str, Any]) -> Union[PID_TYPE, Any]: | ||
""" | ||
Receive a task. | ||
:param task: The task message | ||
""" | ||
from plumpy.rmq import communications | ||
|
||
task_type = task[TASK_KEY] | ||
if task_type == LAUNCH_TASK: | ||
return await self._launch(communicator, **task.get(TASK_ARGS, {})) | ||
if task_type == CONTINUE_TASK: | ||
return await self._continue(communicator, **task.get(TASK_ARGS, {})) | ||
if task_type == CREATE_TASK: | ||
return await self._create(communicator, **task.get(TASK_ARGS, {})) | ||
|
||
raise communications.TaskRejected | ||
|
||
async def _launch( | ||
self, | ||
_communicator: kiwipy.Communicator, | ||
process_class: str, | ||
persist: bool, | ||
nowait: bool, | ||
init_args: Optional[Sequence[Any]] = None, | ||
init_kwargs: Optional[Dict[str, Any]] = None, | ||
) -> Union[PID_TYPE, Any]: | ||
""" | ||
Launch the process | ||
:param _communicator: the communicator | ||
:param process_class: the process class to launch | ||
:param persist: should the process be persisted | ||
:param nowait: if True only return when the process finishes | ||
:param init_args: positional arguments to the process constructor | ||
:param init_kwargs: keyword arguments to the process constructor | ||
:return: the pid of the created process or the outputs (if nowait=False) | ||
""" | ||
from plumpy.rmq import communications | ||
|
||
if persist and not self._persister: | ||
raise communications.TaskRejected('Cannot persist process, no persister') | ||
|
||
if init_args is None: | ||
init_args = () | ||
if init_kwargs is None: | ||
init_kwargs = {} | ||
|
||
proc_class = self._loader.load_object(process_class) | ||
proc = proc_class(*init_args, **init_kwargs) | ||
if persist and self._persister is not None: | ||
self._persister.save_checkpoint(proc) | ||
|
||
if nowait: | ||
# XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails | ||
asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 | ||
return proc.pid | ||
|
||
await proc.step_until_terminated() | ||
|
||
return proc.future().result() | ||
|
||
async def _continue( | ||
self, _communicator: kiwipy.Communicator, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None | ||
) -> Union[PID_TYPE, Any]: | ||
""" | ||
Continue the process | ||
:param _communicator: the communicator | ||
:param pid: the pid of the process to continue | ||
:param nowait: if True don't wait for the process to complete | ||
:param tag: the checkpoint tag to continue from | ||
""" | ||
from plumpy.rmq import communications | ||
|
||
if not self._persister: | ||
LOGGER.warning('rejecting task: cannot continue process<%d> because no persister is available', pid) | ||
raise communications.TaskRejected('Cannot continue process, no persister') | ||
|
||
# Do not catch exceptions here, because if these operations fail, the continue task should except and bubble up | ||
saved_state = self._persister.load_checkpoint(pid, tag) | ||
proc = cast('Process', saved_state.unbundle(self._load_context)) | ||
|
||
if nowait: | ||
# XXX: can return a reference and gracefully use task to cancel itself when the upper call stack fails | ||
asyncio.ensure_future(proc.step_until_terminated()) # noqa: RUF006 | ||
return proc.pid | ||
|
||
await proc.step_until_terminated() | ||
|
||
return proc.future().result() | ||
|
||
async def _create( | ||
self, | ||
_communicator: kiwipy.Communicator, | ||
process_class: str, | ||
persist: bool, | ||
init_args: Optional[Sequence[Any]] = None, | ||
init_kwargs: Optional[Dict[str, Any]] = None, | ||
) -> 'PID_TYPE': | ||
""" | ||
Create the process | ||
:param _communicator: the communicator | ||
:param process_class: the process class to create | ||
:param persist: should the process be persisted | ||
:param init_args: positional arguments to the process constructor | ||
:param init_kwargs: keyword arguments to the process constructor | ||
:return: the pid of the created process | ||
""" | ||
from plumpy.rmq import communications | ||
|
||
if persist and not self._persister: | ||
raise communications.TaskRejected('Cannot persist process, no persister') | ||
|
||
if init_args is None: | ||
init_args = () | ||
if init_kwargs is None: | ||
init_kwargs = {} | ||
|
||
proc_class = self._loader.load_object(process_class) | ||
proc = proc_class(*init_args, **init_kwargs) | ||
if persist and self._persister is not None: | ||
self._persister.save_checkpoint(proc) | ||
|
||
return proc.pid |
Oops, something went wrong.