From aebec4f17d5d5ee48baf1f23b7ffa5d37b4bd020 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 21:00:37 +0100 Subject: [PATCH] RmqCoordinator is constructed from comm for tests --- tests/rmq/__init__.py | 50 ++++++++++++++++++ tests/rmq/test_communications.py | 88 ++++++++----------------------- tests/rmq/test_communicator.py | 82 +++++++--------------------- tests/rmq/test_process_control.py | 73 ++++++------------------- 4 files changed, 107 insertions(+), 186 deletions(-) diff --git a/tests/rmq/__init__.py b/tests/rmq/__init__.py index e69de29b..e0f263b8 100644 --- a/tests/rmq/__init__.py +++ b/tests/rmq/__init__.py @@ -0,0 +1,50 @@ + +import kiwipy + + +class RmqCoordinator: + def __init__(self, comm: kiwipy.Communicator): + self._comm = comm + + def add_rpc_subscriber(self, subscriber, identifier=None): + return self._comm.add_rpc_subscriber(subscriber, identifier) + + def add_broadcast_subscriber( + self, + subscriber, + subject_filter=None, + identifier=None, + ): + subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) + return self._comm.add_broadcast_subscriber(subscriber, identifier) + + def add_task_subscriber(self, subscriber, identifier=None): + return self._comm.add_task_subscriber(subscriber, identifier) + + def remove_rpc_subscriber(self, identifier): + return self._comm.remove_rpc_subscriber(identifier) + + def remove_broadcast_subscriber(self, identifier): + return self._comm.remove_broadcast_subscriber(identifier) + + def remove_task_subscriber(self, identifier): + return self._comm.remove_task_subscriber(identifier) + + def rpc_send(self, recipient_id, msg): + return self._comm.rpc_send(recipient_id, msg) + + def broadcast_send( + self, + body, + sender=None, + subject=None, + correlation_id=None, + ): + return self._comm.broadcast_send(body, sender, subject, correlation_id) + + def task_send(self, task, no_reply=False): + return self._comm.task_send(task, no_reply) + + def close(self): + self._comm.close() + diff --git a/tests/rmq/test_communications.py b/tests/rmq/test_communications.py index a24022fb..e45994b2 100644 --- a/tests/rmq/test_communications.py +++ b/tests/rmq/test_communications.py @@ -1,87 +1,45 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.rmq.communications` module.""" -import kiwipy import pytest +import kiwipy from plumpy.rmq.communications import LoopCommunicator +from . import RmqCoordinator -class Subscriber: - """Test class that mocks a subscriber.""" - - def __call__(self): - pass - - -class CoordinatorWithLoopCommunicatorHelper: - def __init__(self): - class _Communicator(kiwipy.CommunicatorHelper): - def task_send(self, task, no_reply=False): - pass - - def rpc_send(self, recipient_id, msg): - pass - - def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): - pass - - self._comm = LoopCommunicator(_Communicator()) - - def add_rpc_subscriber(self, subscriber, identifier=None): - return self._comm.add_rpc_subscriber(subscriber, identifier) - - def add_broadcast_subscriber( - self, - subscriber, - subject_filter=None, - identifier=None, - ): - subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) - return self._comm.add_broadcast_subscriber(subscriber, identifier) - - def add_task_subscriber(self, subscriber, identifier=None): - return self._comm.add_task_subscriber(subscriber, identifier) - - def remove_rpc_subscriber(self, identifier): - return self._comm.remove_rpc_subscriber(identifier) - - def remove_broadcast_subscriber(self, identifier): - return self._comm.remove_broadcast_subscriber(identifier) - - def remove_task_subscriber(self, identifier): - return self._comm.remove_task_subscriber(identifier) - - def rpc_send(self, recipient_id, msg): - return self._comm.rpc_send(recipient_id, msg) +@pytest.fixture +def _coordinator(): + """Return an instance of `LoopCommunicator`.""" - def broadcast_send( - self, - body, - sender=None, - subject=None, - correlation_id=None, - ): - return self._comm.broadcast_send(body, sender, subject, correlation_id) + class _Communicator(kiwipy.CommunicatorHelper): + def task_send(self, task, no_reply=False): + pass - def task_send(self, task, no_reply=False): - return self._comm.task_send(task, no_reply) + def rpc_send(self, recipient_id, msg): + pass - def close(self): - self._comm.close() + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): + pass + comm = LoopCommunicator(_Communicator()) + coordinator = RmqCoordinator(comm) -@pytest.fixture -def _coordinator(): - """Return an instance of `LoopCommunicator`.""" - coordinator = CoordinatorWithLoopCommunicatorHelper() yield coordinator + coordinator.close() @pytest.fixture def subscriber(): - """Return an instance of `Subscriber`.""" + """Return an instance of mocked `Subscriber`.""" + + class Subscriber: + """Test class that mocks a subscriber.""" + + def __call__(self): + pass + return Subscriber() diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 9dd5fa72..42c5d748 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -10,13 +10,13 @@ import shortuuid import yaml -import kiwipy from kiwipy.rmq import RmqThreadCommunicator import plumpy from plumpy.coordinator import Coordinator from plumpy.rmq import communications, process_control +from . import RmqCoordinator from .. import utils @@ -30,71 +30,27 @@ def persister(): shutil.rmtree(_tmppath) -class CoordinatorWithLoopRmqThreadCommunicator: - def __init__(self): - message_exchange = f'{__file__}.{shortuuid.uuid()}' - task_exchange = f'{__file__}.{shortuuid.uuid()}' - task_queue = f'{__file__}.{shortuuid.uuid()}' - - thread_comm = RmqThreadCommunicator.connect( - connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, - message_exchange=message_exchange, - task_exchange=task_exchange, - task_queue=task_queue, - decoder=functools.partial(yaml.load, Loader=yaml.Loader), - ) - - loop = asyncio.get_event_loop() - loop.set_debug(True) - self._comm = communications.LoopCommunicator(thread_comm, loop=loop) - - def add_rpc_subscriber(self, subscriber, identifier=None): - return self._comm.add_rpc_subscriber(subscriber, identifier) - - def add_broadcast_subscriber( - self, - subscriber, - subject_filter=None, - identifier=None, - ): - subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) - return self._comm.add_broadcast_subscriber(subscriber, identifier) - - def add_task_subscriber(self, subscriber, identifier=None): - return self._comm.add_task_subscriber(subscriber, identifier) - - def remove_rpc_subscriber(self, identifier): - return self._comm.remove_rpc_subscriber(identifier) - - def remove_broadcast_subscriber(self, identifier): - return self._comm.remove_broadcast_subscriber(identifier) - - def remove_task_subscriber(self, identifier): - return self._comm.remove_task_subscriber(identifier) - - def rpc_send(self, recipient_id, msg): - return self._comm.rpc_send(recipient_id, msg) - - def broadcast_send( - self, - body, - sender=None, - subject=None, - correlation_id=None, - ): - return self._comm.broadcast_send(body, sender, subject, correlation_id) - - def task_send(self, task, no_reply=False): - return self._comm.task_send(task, no_reply) - - def close(self): - self._comm.close() - - @pytest.fixture def _coordinator(): - coordinator = CoordinatorWithLoopRmqThreadCommunicator() + message_exchange = f'{__file__}.{shortuuid.uuid()}' + task_exchange = f'{__file__}.{shortuuid.uuid()}' + task_queue = f'{__file__}.{shortuuid.uuid()}' + + thread_comm = RmqThreadCommunicator.connect( + connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, + message_exchange=message_exchange, + task_exchange=task_exchange, + task_queue=task_queue, + decoder=functools.partial(yaml.load, Loader=yaml.Loader), + ) + + loop = asyncio.get_event_loop() + loop.set_debug(True) + comm = communications.LoopCommunicator(thread_comm, loop=loop) + coordinator = RmqCoordinator(comm) + yield coordinator + coordinator.close() diff --git a/tests/rmq/test_process_control.py b/tests/rmq/test_process_control.py index de779d78..8130d599 100644 --- a/tests/rmq/test_process_control.py +++ b/tests/rmq/test_process_control.py @@ -9,70 +9,27 @@ import plumpy from plumpy.rmq import process_control +from . import RmqCoordinator from .. import utils -class CoordinatorWithRmqThreadCommunicator: - def __init__(self): - message_exchange = f'{__file__}.{shortuuid.uuid()}' - task_exchange = f'{__file__}.{shortuuid.uuid()}' - task_queue = f'{__file__}.{shortuuid.uuid()}' - - self._comm = rmq.RmqThreadCommunicator.connect( - connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, - message_exchange=message_exchange, - task_exchange=task_exchange, - task_queue=task_queue, - ) - self._comm._loop.set_debug(True) - - def add_rpc_subscriber(self, subscriber, identifier=None): - return self._comm.add_rpc_subscriber(subscriber, identifier) - - def add_broadcast_subscriber( - self, - subscriber, - subject_filter=None, - identifier=None, - ): - subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) - return self._comm.add_broadcast_subscriber(subscriber, identifier) - - def add_task_subscriber(self, subscriber, identifier=None): - return self._comm.add_task_subscriber(subscriber, identifier) - - def remove_rpc_subscriber(self, identifier): - return self._comm.remove_rpc_subscriber(identifier) - - def remove_broadcast_subscriber(self, identifier): - return self._comm.remove_broadcast_subscriber(identifier) - - def remove_task_subscriber(self, identifier): - return self._comm.remove_task_subscriber(identifier) - - def rpc_send(self, recipient_id, msg): - return self._comm.rpc_send(recipient_id, msg) - - def broadcast_send( - self, - body, - sender=None, - subject=None, - correlation_id=None, - ): - return self._comm.broadcast_send(body, sender, subject, correlation_id) - - def task_send(self, task, no_reply=False): - return self._comm.task_send(task, no_reply) - - def close(self): - self._comm.close() - - @pytest.fixture def _coordinator(): - coordinator = CoordinatorWithRmqThreadCommunicator() + message_exchange = f'{__file__}.{shortuuid.uuid()}' + task_exchange = f'{__file__}.{shortuuid.uuid()}' + task_queue = f'{__file__}.{shortuuid.uuid()}' + + comm = rmq.RmqThreadCommunicator.connect( + connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, + message_exchange=message_exchange, + task_exchange=task_exchange, + task_queue=task_queue, + ) + comm._loop.set_debug(True) + coordinator = RmqCoordinator(comm) + yield coordinator + coordinator.close()