Skip to content

Commit

Permalink
Interface change from communicator -> coordinator
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 17, 2024
1 parent c0a8bbd commit 3b8104b
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 32 deletions.
13 changes: 13 additions & 0 deletions src/plumpy/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,16 @@ def remove_rpc_subscriber(self, identifier): ...
def remove_broadcast_subscriber(self, identifier): ...

def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ...

class Coordinator(Protocol):
def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ...

def add_broadcast_subscriber(
self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None
) -> Any: ...

def remove_rpc_subscriber(self, identifier): ...

def remove_broadcast_subscriber(self, identifier): ...

def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ...
32 changes: 16 additions & 16 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
cast,
)

from plumpy.coordinator import Communicator
from plumpy.coordinator import Coordinator

try:
from aiocontextvars import ContextVar
Expand Down Expand Up @@ -266,7 +266,7 @@ def __init__(
pid: Optional[PID_TYPE] = None,
logger: Optional[logging.Logger] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
communicator: Optional[Communicator] = None,
coordinator: Optional[Coordinator] = None,
) -> None:
"""
The signature of the constructor should not be changed by subclassing processes.
Expand Down Expand Up @@ -305,7 +305,7 @@ def __init__(
self._future = persistence.SavableFuture(loop=self._loop)
self._event_helper = EventHelper(ProcessListener)
self._logger = logger
self._communicator = communicator
self._coordinator = coordinator

@super_check
def init(self) -> None:
Expand All @@ -315,19 +315,19 @@ def init(self) -> None:
"""
self._cleanups = [] # a list of functions to be ran on terminated

if self._communicator is not None:
if self._coordinator is not None:
try:
identifier = self._communicator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid))
self.add_cleanup(functools.partial(self._communicator.remove_rpc_subscriber, identifier))
identifier = self._coordinator.add_rpc_subscriber(self.message_receive, identifier=str(self.pid))
self.add_cleanup(functools.partial(self._coordinator.remove_rpc_subscriber, identifier))
except concurrent.futures.TimeoutError:
self.logger.exception('Process<%s>: failed to register as an RPC subscriber', self.pid)

try:
# filter out state change broadcasts
identifier = self._communicator.add_broadcast_subscriber(
identifier = self._coordinator.add_broadcast_subscriber(
self.broadcast_receive, subject_filter=re.compile(r'^(?!state_changed).*'), identifier=str(self.pid)
)
self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier))
self.add_cleanup(functools.partial(self._coordinator.remove_broadcast_subscriber, identifier))
except concurrent.futures.TimeoutError:
self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid)

Expand Down Expand Up @@ -449,7 +449,7 @@ def launch(
pid=pid,
logger=logger,
loop=self.loop,
communicator=self._communicator,
coordinator=self._coordinator,
)
self.loop.create_task(process.step_until_terminated())
return process
Expand Down Expand Up @@ -645,7 +645,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi
self._future = persistence.SavableFuture()
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None
self._coordinator = None

if 'loop' in load_context:
self._loop = load_context.loop
Expand All @@ -654,8 +654,8 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi

self._state: process_states.State = self.recreate_state(saved_state['_state'])

if 'communicator' in load_context:
self._communicator = load_context.communicator
if 'coordinator' in load_context:
self._coordinator = load_context.coordinator

if 'logger' in load_context:
self._logger = load_context.logger
Expand Down Expand Up @@ -740,15 +740,15 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None:
elif state_label == process_states.ProcessState.KILLED:
call_with_super_check(self.on_killed)

if self._communicator and isinstance(self.state, enum.Enum):
if self._coordinator and isinstance(self.state, enum.Enum):
# FIXME: move all to `coordinator.broadcast()` call and in rmq implement coordinator
from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed

from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None
subject = f'state_changed.{from_label}.{self.state.value}'
self.logger.info('Process<%s>: Broadcasting state change: %s', self.pid, subject)
try:
self._communicator.broadcast_send(body=None, sender=self.pid, subject=subject)
self._coordinator.broadcast_send(body=None, sender=self.pid, subject=subject)
except (CommunicatorConnectionClosed, CommunicatorChannelInvalidStateError):
message = 'Process<%s>: no connection available to broadcast state change from %s to %s'
self.logger.warning(message, self.pid, from_label, self.state.value)
Expand Down Expand Up @@ -938,7 +938,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non

# region Communication

def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any:
def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any:
"""
Coroutine called when the process receives a message from the communicator
Expand Down Expand Up @@ -970,7 +970,7 @@ def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any:
raise RuntimeError('Unknown intent')

def broadcast_receive(
self, _comm: Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any
self, _comm: Coordinator, body: Any, sender: Any, subject: Any, correlation_id: Any
) -> Optional[concurrent.futures.Future]:
"""
Coroutine called when the process receives a message from the communicator
Expand Down
6 changes: 3 additions & 3 deletions src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
cast,
)

from plumpy.coordinator import Communicator
from plumpy.coordinator import Coordinator

from . import lang, mixins, persistence, process_states, processes
from .utils import PID_TYPE, SAVED_STATE_TYPE
Expand Down Expand Up @@ -128,9 +128,9 @@ def __init__(
pid: Optional[PID_TYPE] = None,
logger: Optional[logging.Logger] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
communicator: Optional[Communicator] = None,
coordinator: Optional[Coordinator] = None,
) -> None:
super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator)
super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, coordinator=coordinator)
self._stepper: Optional[Stepper] = None
self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {}

Expand Down
24 changes: 12 additions & 12 deletions tests/rmq/test_process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def sync_controller(thread_communicator: rmq.RmqThreadCommunicator):
class TestRemoteProcessController:
@pytest.mark.asyncio
async def test_pause(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())
# Send a pause message
Expand All @@ -57,7 +57,7 @@ async def test_pause(self, thread_communicator, async_controller):

@pytest.mark.asyncio
async def test_play(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())
assert proc.pause()
Expand All @@ -75,7 +75,7 @@ async def test_play(self, thread_communicator, async_controller):

@pytest.mark.asyncio
async def test_kill(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the event loop
asyncio.ensure_future(proc.step_until_terminated())

Expand All @@ -88,7 +88,7 @@ async def test_kill(self, thread_communicator, async_controller):

@pytest.mark.asyncio
async def test_status(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())

Expand All @@ -108,7 +108,7 @@ def on_broadcast_receive(**msg):

thread_communicator.add_broadcast_subscriber(on_broadcast_receive)

proc = utils.DummyProcess(communicator=thread_communicator)
proc = utils.DummyProcess(coordinator=thread_communicator)
proc.execute()

expected_subjects = []
Expand All @@ -123,7 +123,7 @@ def on_broadcast_receive(**msg):
class TestRemoteProcessThreadController:
@pytest.mark.asyncio
async def test_pause(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)

# Send a pause message
pause_future = sync_controller.pause_process(proc.pid)
Expand All @@ -140,7 +140,7 @@ async def test_pause_all(self, thread_communicator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
for _ in range(10):
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))
procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator))

sync_controller.pause_all("Slow yo' roll")
# Wait until they are all paused
Expand All @@ -151,7 +151,7 @@ async def test_play_all(self, thread_communicator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
for _ in range(10):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
procs.append(proc)
proc.pause('hold tight')

Expand All @@ -162,7 +162,7 @@ async def test_play_all(self, thread_communicator, sync_controller):

@pytest.mark.asyncio
async def test_play(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
assert proc.pause()

# Send a play message
Expand All @@ -176,7 +176,7 @@ async def test_play(self, thread_communicator, sync_controller):

@pytest.mark.asyncio
async def test_kill(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)

# Send a kill message
kill_future = sync_controller.kill_process(proc.pid)
Expand All @@ -193,7 +193,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
for _ in range(10):
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))
procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator))

msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down')

Expand All @@ -203,7 +203,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):

@pytest.mark.asyncio
async def test_status(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())

Expand Down
2 changes: 1 addition & 1 deletion tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id):
messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id})

communicator.add_broadcast_subscriber(on_broadcast_receive)
proc = utils.DummyProcess(communicator=communicator)
proc = utils.DummyProcess(coordinator=communicator)
proc.execute()

expected_subjects = []
Expand Down

0 comments on commit 3b8104b

Please sign in to comment.