diff --git a/tests/rmq/test_communications.py b/tests/rmq/test_communications.py index 00b7f1c6..4a585dd9 100644 --- a/tests/rmq/test_communications.py +++ b/tests/rmq/test_communications.py @@ -1,10 +1,13 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.rmq.communications` module.""" +from typing import override +import kiwipy +from kiwipy.communications import exceptions import pytest -from kiwipy import CommunicatorHelper from plumpy.rmq.communications import LoopCommunicator +import shortuuid class Subscriber: @@ -14,21 +17,84 @@ def __call__(self): pass -class Communicator(CommunicatorHelper): - def task_send(self, task, no_reply=False): - pass +class CoordinatorWithLoopCommunicatorHelper: + def __init__(self): + class _Communicator(kiwipy.CommunicatorHelper): + def task_send(self, task, no_reply=False): + pass + + def rpc_send(self, recipient_id, msg): + pass + + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): + pass + + @override + def add_broadcast_subscriber(self, subscriber, subject_filter, identifier=None): + """Duplicate the add_broadcast_subscriber from CommunicatorHelper and add support for + passing `subject_filter`. + """ + + self._ensure_open() + identifier = identifier or shortuuid.uuid() + subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) + if identifier in self._broadcast_subscribers: + raise exceptions.DuplicateSubscriberIdentifier(f"Broadcast identifier '{identifier}'") + + self._broadcast_subscribers[identifier] = subscriber + return identifier + + self._comm = LoopCommunicator(_Communicator()) + + def add_rpc_subscriber(self, subscriber, identifier=None): + return self._comm.add_rpc_subscriber(subscriber, identifier) + + def add_broadcast_subscriber( + self, + subscriber, + subject_filter=None, + identifier=None, + ): + subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) + return self._comm.add_broadcast_subscriber(subscriber, subject_filter, identifier) + + def add_task_subscriber(self, subscriber, identifier=None): + return self._comm.add_task_subscriber(subscriber, identifier) + + def remove_rpc_subscriber(self, identifier): + return self._comm.remove_rpc_subscriber(identifier) + + def remove_broadcast_subscriber(self, identifier): + return self._comm.remove_broadcast_subscriber(identifier) + + def remove_task_subscriber(self, identifier): + return self._comm.remove_task_subscriber(identifier) def rpc_send(self, recipient_id, msg): - pass + return self._comm.rpc_send(recipient_id, msg) - def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): - pass + def broadcast_send( + self, + body, + sender=None, + subject=None, + correlation_id=None, + ): + return self._comm.broadcast_send(body, sender, subject, correlation_id) + + def task_send(self, task, no_reply=False): + return self._comm.task_send(task, no_reply) + + def close(self): + self._comm.close() @pytest.fixture -def loop_communicator(): +def _communicator(): """Return an instance of `LoopCommunicator`.""" - return LoopCommunicator(Communicator()) + coordinator = CoordinatorWithLoopCommunicatorHelper() + yield coordinator + coordinator.close() @pytest.fixture @@ -37,40 +103,40 @@ def subscriber(): return Subscriber() -def test_add_rpc_subscriber(loop_communicator, subscriber): +def test_add_rpc_subscriber(_communicator, subscriber): """Test the `LoopCommunicator.add_rpc_subscriber` method.""" - assert loop_communicator.add_rpc_subscriber(subscriber) is not None + assert _communicator.add_rpc_subscriber(subscriber) is not None identifier = 'identifier' - assert loop_communicator.add_rpc_subscriber(subscriber, identifier) == identifier + assert _communicator.add_rpc_subscriber(subscriber, identifier) == identifier -def test_remove_rpc_subscriber(loop_communicator, subscriber): +def test_remove_rpc_subscriber(_communicator, subscriber): """Test the `LoopCommunicator.remove_rpc_subscriber` method.""" - identifier = loop_communicator.add_rpc_subscriber(subscriber) - loop_communicator.remove_rpc_subscriber(identifier) + identifier = _communicator.add_rpc_subscriber(subscriber) + _communicator.remove_rpc_subscriber(identifier) -def test_add_broadcast_subscriber(loop_communicator, subscriber): +def test_add_broadcast_subscriber(_communicator, subscriber): """Test the `LoopCommunicator.add_broadcast_subscriber` method.""" - assert loop_communicator.add_broadcast_subscriber(subscriber) is not None + assert _communicator.add_broadcast_subscriber(subscriber) is not None identifier = 'identifier' - assert loop_communicator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier + assert _communicator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier -def test_remove_broadcast_subscriber(loop_communicator, subscriber): +def test_remove_broadcast_subscriber(_communicator, subscriber): """Test the `LoopCommunicator.remove_broadcast_subscriber` method.""" - identifier = loop_communicator.add_broadcast_subscriber(subscriber) - loop_communicator.remove_broadcast_subscriber(identifier) + identifier = _communicator.add_broadcast_subscriber(subscriber) + _communicator.remove_broadcast_subscriber(identifier) -def test_add_task_subscriber(loop_communicator, subscriber): +def test_add_task_subscriber(_communicator, subscriber): """Test the `LoopCommunicator.add_task_subscriber` method.""" - assert loop_communicator.add_task_subscriber(subscriber) is not None + assert _communicator.add_task_subscriber(subscriber) is not None -def test_remove_task_subscriber(loop_communicator, subscriber): +def test_remove_task_subscriber(_communicator, subscriber): """Test the `LoopCommunicator.remove_task_subscriber` method.""" - identifier = loop_communicator.add_task_subscriber(subscriber) - loop_communicator.remove_task_subscriber(identifier) + identifier = _communicator.add_task_subscriber(subscriber) + _communicator.remove_task_subscriber(identifier) diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 80c1ac71..ef336293 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -11,7 +11,7 @@ import pytest import shortuuid import yaml -from kiwipy import BroadcastFilter, rmq +from kiwipy import rmq import plumpy from plumpy.rmq import communications, process_comms diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 9de211ee..18abcacf 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import asyncio +from typing import Any import kiwipy import pytest @@ -7,45 +8,89 @@ from kiwipy import rmq import plumpy -from plumpy.message import KILL_MSG, MESSAGE_KEY from plumpy.rmq import process_comms from .. import utils -@pytest.fixture -def thread_communicator(): - message_exchange = f'{__file__}.{shortuuid.uuid()}' - task_exchange = f'{__file__}.{shortuuid.uuid()}' - task_queue = f'{__file__}.{shortuuid.uuid()}' +class CoordinatorWithRmqThreadCommunicator: + def __init__(self): + message_exchange = f'{__file__}.{shortuuid.uuid()}' + task_exchange = f'{__file__}.{shortuuid.uuid()}' + task_queue = f'{__file__}.{shortuuid.uuid()}' + + self._comm = rmq.RmqThreadCommunicator.connect( + connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, + message_exchange=message_exchange, + task_exchange=task_exchange, + task_queue=task_queue, + ) + self._comm._loop.set_debug(True) + + def add_rpc_subscriber(self, subscriber, identifier=None) -> Any: + return self._comm.add_rpc_subscriber(subscriber, identifier) + + def add_broadcast_subscriber( + self, + subscriber, + subject_filter=None, + identifier=None, + ): + subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) + return self._comm.add_broadcast_subscriber(subscriber, identifier) + + def add_task_subscriber(self, subscriber, identifier=None): + return self._comm.add_task_subscriber(subscriber, identifier) + + def remove_rpc_subscriber(self, identifier): + return self._comm.remove_rpc_subscriber(identifier) - communicator = rmq.RmqThreadCommunicator.connect( - connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, - message_exchange=message_exchange, - task_exchange=task_exchange, - task_queue=task_queue, - ) - communicator._loop.set_debug(True) + def remove_broadcast_subscriber(self, identifier): + return self._comm.remove_broadcast_subscriber(identifier) - yield communicator + def remove_task_subscriber(self, identifier): + return self._comm.remove_task_subscriber(identifier) - communicator.close() + def rpc_send(self, recipient_id, msg): + return self._comm.rpc_send(recipient_id, msg) + + def broadcast_send( + self, + body, + sender=None, + subject=None, + correlation_id=None, + ): + return self._comm.broadcast_send(body, sender, subject, correlation_id) + + def task_send(self, task, no_reply=False): + return self._comm.task_send(task, no_reply) + + def close(self): + self._comm.close() + + +@pytest.fixture +def _coordinator(): + coordinator = CoordinatorWithRmqThreadCommunicator() + yield coordinator + coordinator.close() @pytest.fixture -def async_controller(thread_communicator: rmq.RmqThreadCommunicator): - yield process_comms.RemoteProcessController(thread_communicator) +def async_controller(_coordinator): + yield process_comms.RemoteProcessController(_coordinator) @pytest.fixture -def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): - yield process_comms.RemoteProcessThreadController(thread_communicator) +def sync_controller(_coordinator): + yield process_comms.RemoteProcessThreadController(_coordinator) class TestRemoteProcessController: @pytest.mark.asyncio - async def test_pause(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_pause(self, _coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) # Send a pause message @@ -56,8 +101,8 @@ async def test_pause(self, thread_communicator, async_controller): assert proc.paused @pytest.mark.asyncio - async def test_play(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_play(self, _coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) assert proc.pause() @@ -74,8 +119,8 @@ async def test_play(self, thread_communicator, async_controller): await async_controller.kill_process(proc.pid) @pytest.mark.asyncio - async def test_kill(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_kill(self, _coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the event loop asyncio.ensure_future(proc.step_until_terminated()) @@ -87,8 +132,8 @@ async def test_kill(self, thread_communicator, async_controller): assert proc.state == plumpy.ProcessState.KILLED @pytest.mark.asyncio - async def test_status(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_status(self, _coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) @@ -100,15 +145,15 @@ async def test_status(self, thread_communicator, async_controller): # make sure proc reach the final state await async_controller.kill_process(proc.pid) - def test_broadcast(self, thread_communicator): + def test_broadcast(self, _coordinator): messages = [] def on_broadcast_receive(**msg): messages.append(msg) - thread_communicator.add_broadcast_subscriber(on_broadcast_receive) + _coordinator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(coordinator=thread_communicator) + proc = utils.DummyProcess(coordinator=_coordinator) proc.execute() expected_subjects = [] @@ -122,8 +167,8 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: @pytest.mark.asyncio - async def test_pause(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_pause(self, _coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Send a pause message pause_future = sync_controller.pause_process(proc.pid) @@ -136,22 +181,22 @@ async def test_pause(self, thread_communicator, sync_controller): assert proc.paused @pytest.mark.asyncio - async def test_pause_all(self, thread_communicator, sync_controller): + async def test_pause_all(self, _coordinator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=_coordinator)) sync_controller.pause_all("Slow yo' roll") # Wait until they are all paused await utils.wait_util(lambda: all([proc.paused for proc in procs])) @pytest.mark.asyncio - async def test_play_all(self, thread_communicator, sync_controller): + async def test_play_all(self, _coordinator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=_coordinator) procs.append(proc) proc.pause('hold tight') @@ -161,8 +206,8 @@ async def test_play_all(self, thread_communicator, sync_controller): await utils.wait_util(lambda: all([not proc.paused for proc in procs])) @pytest.mark.asyncio - async def test_play(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_play(self, _coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) assert proc.pause() # Send a play message @@ -175,8 +220,8 @@ async def test_play(self, thread_communicator, sync_controller): assert proc.state == plumpy.ProcessState.CREATED @pytest.mark.asyncio - async def test_kill(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_kill(self, _coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Send a kill message kill_future = sync_controller.kill_process(proc.pid) @@ -189,11 +234,11 @@ async def test_kill(self, thread_communicator, sync_controller): assert proc.state == plumpy.ProcessState.KILLED @pytest.mark.asyncio - async def test_kill_all(self, thread_communicator, sync_controller): + async def test_kill_all(self, _coordinator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=_coordinator)) msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down') @@ -202,8 +247,8 @@ async def test_kill_all(self, thread_communicator, sync_controller): assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) @pytest.mark.asyncio - async def test_status(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_status(self, _coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated())