diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py new file mode 100644 index 00000000..214fc18f --- /dev/null +++ b/src/plumpy/coordinator.py @@ -0,0 +1,21 @@ +from typing import Any, Callable, Protocol + +RpcSubscriber = Callable[['Communicator', Any], Any] +BroadcastSubscriber = Callable[['Communicator', Any, Any, Any, Any], Any] + +class Communicator(Protocol): + + def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: + ... + + def add_broadcast_subscriber(self, subscriber: BroadcastSubscriber, identifier=None) -> Any: + ... + + def remove_rpc_subscriber(self, identifier): + ... + + def remove_broadcast_subscriber(self, identifier): + ... + + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: + ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 091fd05d..93c52767 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -3,6 +3,7 @@ import abc import asyncio +import concurrent.futures import contextlib import copy import enum @@ -30,6 +31,8 @@ cast, ) +from plumpy.coordinator import Communicator + try: from aiocontextvars import ContextVar except ModuleNotFoundError: @@ -252,7 +255,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None, + communicator: Optional[Communicator] = None, ) -> None: """ The signature of the constructor should not be changed by subclassing processes. @@ -305,15 +308,16 @@ def init(self) -> None: try: identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier)) - except kiwipy.TimeoutError: + except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) try: # filter out state change broadcasts + # TODO: pattern filter should be moved to add_broadcast_subscriber. subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) - except kiwipy.TimeoutError: + except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) if not self._future.done(): @@ -697,8 +701,6 @@ def on_entering(self, state: process_states.State) -> None: call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore def on_entered(self, from_state: Optional[process_states.State]) -> None: - from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed - # Map these onto direct functions that the subclass can implement state_label = self._state.LABEL if state_label == process_states.ProcessState.RUNNING: @@ -713,6 +715,8 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: call_with_super_check(self.on_killed) if self._communicator and isinstance(self.state, enum.Enum): + from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed + from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None subject = f'state_changed.{from_label}.{self.state.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) @@ -721,7 +725,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: except (CommunicatorConnectionClosed, CommunicatorChannelInvalidStateError): message = 'Process<%s>: no connection available to broadcast state change from %s to %s' self.logger.warning(message, self.pid, from_label, self.state.value) - except kiwipy.TimeoutError: + except concurrent.futures.TimeoutError: message = 'Process<%s>: sending broadcast of state change from %s to %s timed out' self.logger.warning(message, self.pid, from_label, self.state.value) @@ -900,7 +904,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication - def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any: + def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any: """ Coroutine called when the process receives a message from the communicator @@ -927,7 +931,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An raise RuntimeError('Unknown intent') def broadcast_receive( - self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any + self, _comm: Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any ) -> Optional[kiwipy.Future]: """ Coroutine called when the process receives a message from the communicator @@ -939,15 +943,24 @@ def broadcast_receive( self.logger.debug( "Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body ) - # If we get a message we recognise then action it, otherwise ignore + fn = None if subject == message.Intent.PLAY: - return self._schedule_rpc(self.play) - if subject == message.Intent.PAUSE: - return self._schedule_rpc(self.pause, msg=body) - if subject == message.Intent.KILL: - return self._schedule_rpc(self.kill, msg=body) - return None + fn = self._schedule_rpc(self.play) + elif subject == message.Intent.PAUSE: + fn = self._schedule_rpc(self.pause, msg=body) + elif subject == message.Intent.KILL: + fn = self._schedule_rpc(self.kill, msg=body) + + if fn is None: + self.logger.warning( + "Process<%s>: received unsupported broadcast message '%s'.", + self.pid, + subject, + ) + return None + + return fn def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future: """