diff --git a/src/plumpy/coordinator.py b/src/plumpy/coordinator.py index 214fc18f..1daaf1f8 100644 --- a/src/plumpy/coordinator.py +++ b/src/plumpy/coordinator.py @@ -1,21 +1,19 @@ -from typing import Any, Callable, Protocol +# -*- coding: utf-8 -*- +from typing import Any, Callable, Pattern, Protocol RpcSubscriber = Callable[['Communicator', Any], Any] BroadcastSubscriber = Callable[['Communicator', Any, Any, Any, Any], Any] -class Communicator(Protocol): - def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: - ... +class Communicator(Protocol): + def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ... - def add_broadcast_subscriber(self, subscriber: BroadcastSubscriber, identifier=None) -> Any: - ... + def add_broadcast_subscriber( + self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None + ) -> Any: ... - def remove_rpc_subscriber(self, identifier): - ... + def remove_rpc_subscriber(self, identifier): ... - def remove_broadcast_subscriber(self, identifier): - ... + def remove_broadcast_subscriber(self, identifier): ... - def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: - ... + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ... diff --git a/src/plumpy/exceptions.py b/src/plumpy/exceptions.py index 1e6f3b26..9dca8fdb 100644 --- a/src/plumpy/exceptions.py +++ b/src/plumpy/exceptions.py @@ -39,5 +39,6 @@ class PersistenceError(Exception): class ClosedError(Exception): """Raised when an mutable operation is attempted on a closed process""" + class TaskRejectedError(Exception): - """ A task was rejected by the coordinacor""" + """A task was rejected by the coordinacor""" diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 2f861d64..01be3951 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -7,7 +7,7 @@ import contextlib from typing import Any, Awaitable, Callable, Generator, Optional -__all__ = ['CancellableAction', 'create_task', 'create_task', 'capture_exceptions'] +__all__ = ['CancellableAction', 'capture_exceptions', 'create_task', 'create_task'] class InvalidFutureError(Exception): diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 02dd123b..6e847f22 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -38,15 +38,14 @@ except ModuleNotFoundError: from contextvars import ContextVar -import kiwipy import yaml from . import events, exceptions, message, persistence, ports, process_states, utils -from .futures import capture_exceptions, CancellableAction from .base import state_machine from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event from .base.utils import call_with_super_check, super_check from .event_helper import EventHelper +from .futures import CancellableAction, capture_exceptions from .process_listener import ProcessListener from .process_spec import ProcessSpec from .utils import PID_TYPE, SAVED_STATE_TYPE, protected @@ -313,9 +312,9 @@ def init(self) -> None: try: # filter out state change broadcasts - # TODO: pattern filter should be moved to add_broadcast_subscriber. - subscriber = kiwipy.BroadcastFilter(self.broadcast_receive, subject=re.compile(r'^(?!state_changed).*')) - identifier = self._communicator.add_broadcast_subscriber(subscriber, identifier=str(self.pid)) + identifier = self._communicator.add_broadcast_subscriber( + self.broadcast_receive, subject_filter=re.compile(r'^(?!state_changed).*'), identifier=str(self.pid) + ) self.add_cleanup(functools.partial(self._communicator.remove_broadcast_subscriber, identifier)) except concurrent.futures.TimeoutError: self.logger.exception('Process<%s>: failed to register as a broadcast subscriber', self.pid) @@ -715,6 +714,7 @@ def on_entered(self, from_state: Optional[process_states.State]) -> None: call_with_super_check(self.on_killed) if self._communicator and isinstance(self.state, enum.Enum): + # FIXME: move all to `coordinator.broadcast()` call and in rmq implement coordinator from plumpy.rmq.exceptions import CommunicatorChannelInvalidStateError, CommunicatorConnectionClosed from_label = cast(enum.Enum, from_state.LABEL).value if from_state is not None else None diff --git a/src/plumpy/rmq/communications.py b/src/plumpy/rmq/communications.py index 9dbafbed..6d1f337c 100644 --- a/src/plumpy/rmq/communications.py +++ b/src/plumpy/rmq/communications.py @@ -131,10 +131,10 @@ def remove_task_subscriber(self, identifier: 'ID_TYPE') -> None: return self._communicator.remove_task_subscriber(identifier) def add_broadcast_subscriber( - self, subscriber: 'BroadcastSubscriber', identifier: Optional['ID_TYPE'] = None + self, subscriber: 'BroadcastSubscriber', subject_filter=None, identifier: Optional['ID_TYPE'] = None ) -> 'ID_TYPE': converted = convert_to_comm(subscriber, self._loop) - return self._communicator.add_broadcast_subscriber(converted, identifier) + return self._communicator.add_broadcast_subscriber(converted, subject_filter, identifier) def remove_broadcast_subscriber(self, identifier: 'ID_TYPE') -> None: return self._communicator.remove_broadcast_subscriber(identifier) diff --git a/src/plumpy/rmq/process_comms.py b/src/plumpy/rmq/process_comms.py index 75db4e6e..f52189a7 100644 --- a/src/plumpy/rmq/process_comms.py +++ b/src/plumpy/rmq/process_comms.py @@ -3,23 +3,22 @@ import asyncio import copy -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union, cast +from typing import Any, Dict, Optional, Sequence, Union import kiwipy +from plumpy import loaders from plumpy.message import ( + KILL_MSG, MESSAGE_KEY, PAUSE_MSG, PLAY_MSG, STATUS_MSG, - KILL_MSG, Intent, create_continue_body, create_create_body, create_launch_body, ) - -from plumpy import loaders from plumpy.utils import PID_TYPE __all__ = [ diff --git a/tests/rmq/test_communications.py b/tests/rmq/test_communications.py index 63813bdc..00b7f1c6 100644 --- a/tests/rmq/test_communications.py +++ b/tests/rmq/test_communications.py @@ -56,7 +56,7 @@ def test_add_broadcast_subscriber(loop_communicator, subscriber): assert loop_communicator.add_broadcast_subscriber(subscriber) is not None identifier = 'identifier' - assert loop_communicator.add_broadcast_subscriber(subscriber, identifier) == identifier + assert loop_communicator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier def test_remove_broadcast_subscriber(loop_communicator, subscriber): diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 26c9a852..80c1ac71 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -7,6 +7,7 @@ import tempfile import uuid +from kiwipy.rmq.communicator import kiwipy import pytest import shortuuid import yaml @@ -81,7 +82,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): assert result == BROADCAST @pytest.mark.asyncio - async def test_broadcast_filter(self, loop_communicator): + async def test_broadcast_filter(self, loop_communicator: kiwipy.Communicator): broadcast_future = asyncio.Future() def ignore_broadcast(_comm, body, sender, subject, correlation_id): @@ -90,7 +91,7 @@ def ignore_broadcast(_comm, body, sender, subject, correlation_id): def get_broadcast(_comm, body, sender, subject, correlation_id): broadcast_future.set_result(True) - loop_communicator.add_broadcast_subscriber(BroadcastFilter(ignore_broadcast, subject='other')) + loop_communicator.add_broadcast_subscriber(ignore_broadcast, subject_filter='other') loop_communicator.add_broadcast_subscriber(get_broadcast) loop_communicator.broadcast_send( **{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} diff --git a/tests/test_processes.py b/tests/test_processes.py index cc57f7dc..87b22244 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -1068,6 +1068,7 @@ def test_paused(self): self.assertSetEqual(events_tester.called, events_tester.expected_events) def test_broadcast(self): + # FIXME: here I need a mock test communicator = kiwipy.LocalCommunicator() messages = []