Skip to content

Commit

Permalink
Forming Communicator protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 14, 2024
1 parent 16ed57d commit 8ac05b0
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 15 deletions.
21 changes: 21 additions & 0 deletions src/plumpy/coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any, Callable, Protocol

RpcSubscriber = Callable[['Communicator', Any], Any]
BroadcastSubscriber = Callable[['Communicator', Any, Any, Any, Any], Any]

class Communicator(Protocol):

def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any:
...

def add_broadcast_subscriber(self, subscriber: BroadcastSubscriber, identifier=None) -> Any:
...

def remove_rpc_subscriber(self, identifier):
...

def remove_broadcast_subscriber(self, identifier):
...

def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool:
...
43 changes: 28 additions & 15 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import abc
import asyncio
import concurrent.futures
import contextlib
import copy
import enum
Expand Down Expand Up @@ -30,6 +31,8 @@
cast,
)

from plumpy.coordinator import Communicator

try:
from aiocontextvars import ContextVar
except ModuleNotFoundError:
Expand Down Expand Up @@ -252,7 +255,7 @@ def __init__(
pid: Optional[PID_TYPE] = None,
logger: Optional[logging.Logger] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
communicator: Optional[kiwipy.Communicator] = None,
communicator: Optional[Communicator] = None,
) -> None:
"""
The signature of the constructor should not be changed by subclassing processes.
Expand Down Expand Up @@ -305,15 +308,16 @@ def init(self) -> None:
try:
identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid))
self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier))
except kiwipy.TimeoutError:
except concurrent.futures.TimeoutError:
self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid)

try:
# filter out state change broadcasts
# TODO: pattern filter should be moved to add_broadcast_subscriber.
subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*'))
identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid))
self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier))
except kiwipy.TimeoutError:
except concurrent.futures.TimeoutError:
self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid)

if not self._future.done():
Expand Down Expand Up @@ -697,8 +701,6 @@ def on_entering(self, state: process_states.State) -> None:
call_with_super_check(self.on_except, state.get_exc_info()) # type: ignore

def on_entered(self, from_state: Optional[process_states.State]) -> None:
from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed

# Map these onto direct functions that the subclass can implement
state_label = self._state.LABEL
if state_label == process_states.ProcessState.RUNNING:
Expand All @@ -713,6 +715,8 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None:
call_with_super_check(self.on_killed)

if self._communicator and isinstance(self.state, enum.Enum):
from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed

from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None
subject = f'state_changed.{from_label}.{self.state.value}'
self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject)
Expand All @@ -721,7 +725,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None:
except (CommunicatorConnectionClosed, CommunicatorChannelInvalidStateError):
message = 'Process<%s>: no connection available to broadcast state change from %s to %s'
self.logger.warning(message, self.pid, from_label, self.state.value)
except kiwipy.TimeoutError:
except concurrent.futures.TimeoutError:
message = 'Process<%s>: sending broadcast of state change from %s to %s timed out'
self.logger.warning(message, self.pid, from_label, self.state.value)

Expand Down Expand Up @@ -900,7 +904,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non

# region Communication

def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> Any:
def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -927,7 +931,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
raise RuntimeError('Unknown intent')

def broadcast_receive(
self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any
self, _comm: Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any
) -> Optional[kiwipy.Future]:
"""
Coroutine called when the process receives a message from the communicator
Expand All @@ -939,15 +943,24 @@ def broadcast_receive(
self.logger.debug(
"Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body
)

# If we get a message we recognise then action it, otherwise ignore
fn = None
if subject == message.Intent.PLAY:
return self._schedule_rpc(self.play)
if subject == message.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=body)
if subject == message.Intent.KILL:
return self._schedule_rpc(self.kill, msg=body)
return None
fn = self._schedule_rpc(self.play)
elif subject == message.Intent.PAUSE:
fn = self._schedule_rpc(self.pause, msg=body)
elif subject == message.Intent.KILL:
fn = self._schedule_rpc(self.kill, msg=body)

if fn is None:
self.logger.warning(
"Process<%s>: received unsupported broadcast message '%s'.",
self.pid,
subject,
)
return None

return fn

def _schedule_rpc(self, callback: Callable[..., Any], *args: Any, **kwargs: Any) -> kiwipy.Future:
"""
Expand Down

0 comments on commit 8ac05b0

Please sign in to comment.