diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 1daaf1f8..cd66a883 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -17,3 +17,16 @@ def remove_rpc_subscriber(self, identifier): ... def remove_broadcast_subscriber(self, identifier): ... def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... + +class Coordinator(Protocol): + def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ... + + def add_broadcast_subscriber( + self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None + ) -> Any: ... + + def remove_rpc_subscriber(self, identifier): ... + + def remove_broadcast_subscriber(self, identifier): ... + + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 9cf3302b..a947eaba 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -32,7 +32,7 @@ cast, ) -from plumpy.coordinator import Communicator +from plumpy.coordinator import Coordinator try: from aiocontextvars import ContextVar @@ -266,7 +266,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[Communicator] = None, + coordinator: Optional[Coordinator] = None, ) -> None: """ The signature of the constructor should not be changed by subclassing processes. @@ -305,7 +305,7 @@ def __init__( self._future = persistence.SavableFuture(loop=self._loop) self._event_helper = EventHelper(ProcessListener) self._logger = logger - self._communicator = communicator + self._coordinator = coordinator @super_check def init(self) -> None: @@ -315,19 +315,19 @@ def init(self) -> None: """ self._cleanups = [] # a list of functions to be ran on terminated - if self._communicator is not None: + if self._coordinator is not None: try: - identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) - self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier)) + identifier = self._coordinator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid)) + self.add_cleanup(functools.partial(self._coordinator.remove_rpc_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid) try: # filter out state change broadcasts - identifier = self._communicator.add_broadcast_subscriber( + identifier = self._coordinator.add_broadcast_subscriber( self.broadcast_receive, subject_filter=re.compile(r'^(?!state_changed).*'), identifier=str(self.pid) ) - self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) + self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) @@ -449,7 +449,7 @@ def launch( pid=pid, logger=logger, loop=self.loop, - communicator=self._communicator, + coordinator=self._coordinator, ) self.loop.create_task(process.step_until_terminated()) return process @@ -645,7 +645,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._future = persistence.SavableFuture() self._event_helper = EventHelper(ProcessListener) self._logger = None - self._communicator = None + self._coordinator = None if 'loop' in load_context: self._loop = load_context.loop @@ -654,8 +654,8 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self._state: process_states.State = self.recreate_state(saved_state['_state']) - if 'communicator' in load_context: - self._communicator = load_context.communicator + if 'coordinator' in load_context: + self._coordinator = load_context.coordinator if 'logger' in load_context: self._logger = load_context.logger @@ -740,7 +740,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: elif state_label == process_states.ProcessState.KILLED: call_with_super_check(self.on_killed) - if self._communicator and isinstance(self.state, enum.Enum): + if self._coordinator and isinstance(self.state, enum.Enum): # FIXME: move all to `coordinator.broadcast()` call and in rmq implement coordinator from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed @@ -748,7 +748,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: subject = f'state_changed.{from_label}.{self.state.value}' self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject) try: - self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject) + self._coordinator.broadcast_send(body=None, sender=self.pid, subject=subject) except (CommunicatorConnectionClosed, CommunicatorChannelInvalidStateError): message = 'Process<%s>: no connection available to broadcast state change from %s to %s' self.logger.warning(message, self.pid, from_label, self.state.value) @@ -938,7 +938,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non # region Communication - def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any: + def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any: """ Coroutine called when the process receives a message from the communicator @@ -970,7 +970,7 @@ def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any: raise RuntimeError('Unknown intent') def broadcast_receive( - self, _comm: Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any + self, _comm: Coordinator, body: Any, sender: Any, subject: Any, correlation_id: Any ) -> Optional[concurrent.futures.Future]: """ Coroutine called when the process receives a message from the communicator diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 7e67253f..5df20bf4 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -23,7 +23,7 @@ cast, ) -from plumpy.coordinator import Communicator +from plumpy.coordinator import Coordinator from . import lang, mixins, persistence, process_states, processes from .utils import PID_TYPE, SAVED_STATE_TYPE @@ -128,9 +128,9 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[Communicator] = None, + coordinator: Optional[Coordinator] = None, ) -> None: - super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) + super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, coordinator=coordinator) self._stepper: Optional[Stepper] = None self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {} diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 4af9a484..9de211ee 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -45,7 +45,7 @@ def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): class TestRemoteProcessController: @pytest.mark.asyncio async def test_pause(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) # Send a pause message @@ -57,7 +57,7 @@ async def test_pause(self, thread_communicator, async_controller): @pytest.mark.asyncio async def test_play(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) assert proc.pause() @@ -75,7 +75,7 @@ async def test_play(self, thread_communicator, async_controller): @pytest.mark.asyncio async def test_kill(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the event loop asyncio.ensure_future(proc.step_until_terminated()) @@ -88,7 +88,7 @@ async def test_kill(self, thread_communicator, async_controller): @pytest.mark.asyncio async def test_status(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) @@ -108,7 +108,7 @@ def on_broadcast_receive(**msg): thread_communicator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(communicator=thread_communicator) + proc = utils.DummyProcess(coordinator=thread_communicator) proc.execute() expected_subjects = [] @@ -123,7 +123,7 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: @pytest.mark.asyncio async def test_pause(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Send a pause message pause_future = sync_controller.pause_process(proc.pid) @@ -140,7 +140,7 @@ async def test_pause_all(self, thread_communicator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) sync_controller.pause_all("Slow yo' roll") # Wait until they are all paused @@ -151,7 +151,7 @@ async def test_play_all(self, thread_communicator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) procs.append(proc) proc.pause('hold tight') @@ -162,7 +162,7 @@ async def test_play_all(self, thread_communicator, sync_controller): @pytest.mark.asyncio async def test_play(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) assert proc.pause() # Send a play message @@ -176,7 +176,7 @@ async def test_play(self, thread_communicator, sync_controller): @pytest.mark.asyncio async def test_kill(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Send a kill message kill_future = sync_controller.kill_process(proc.pid) @@ -193,7 +193,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(communicator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down') @@ -203,7 +203,7 @@ async def test_kill_all(self, thread_communicator, sync_controller): @pytest.mark.asyncio async def test_status(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(communicator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_communicator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) diff --git a/tests/test_processes.py b/tests/test_processes.py index 7b232689..927d42b3 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -1075,7 +1075,7 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id): messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id}) communicator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(communicator=communicator) + proc = utils.DummyProcess(coordinator=communicator) proc.execute() expected_subjects = []