From 2f1dab860a307c949b4581f303c90bed7b55d4fd Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 16:30:15 +0100 Subject: [PATCH] Showing how using interface can avoid making change in kiwipy --- tests/rmq/test_communicator.py | 2 +- tests/rmq/test_process_comms.py | 133 +++++++++++++++++++++----------- 2 files changed, 90 insertions(+), 45 deletions(-) diff --git a/tests/rmq/test_communicator.py b/tests/rmq/test_communicator.py index 80c1ac71..ef336293 100644 --- a/tests/rmq/test_communicator.py +++ b/tests/rmq/test_communicator.py @@ -11,7 +11,7 @@ import pytest import shortuuid import yaml -from kiwipy import BroadcastFilter, rmq +from kiwipy import rmq import plumpy from plumpy.rmq import communications, process_comms diff --git a/tests/rmq/test_process_comms.py b/tests/rmq/test_process_comms.py index 9de211ee..ed718e0a 100644 --- a/tests/rmq/test_process_comms.py +++ b/tests/rmq/test_process_comms.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import asyncio +from typing import Any import kiwipy import pytest @@ -7,45 +8,89 @@ from kiwipy import rmq import plumpy -from plumpy.message import KILL_MSG, MESSAGE_KEY from plumpy.rmq import process_comms from .. import utils -@pytest.fixture -def thread_communicator(): - message_exchange = f'{__file__}.{shortuuid.uuid()}' - task_exchange = f'{__file__}.{shortuuid.uuid()}' - task_queue = f'{__file__}.{shortuuid.uuid()}' +class ThreadCoordinator: + def __init__(self): + message_exchange = f'{__file__}.{shortuuid.uuid()}' + task_exchange = f'{__file__}.{shortuuid.uuid()}' + task_queue = f'{__file__}.{shortuuid.uuid()}' + + self._comm = rmq.RmqThreadCommunicator.connect( + connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, + message_exchange=message_exchange, + task_exchange=task_exchange, + task_queue=task_queue, + ) + self._comm._loop.set_debug(True) + + def add_rpc_subscriber(self, subscriber, identifier=None) -> Any: + return self._comm.add_rpc_subscriber(subscriber, identifier) + + def add_broadcast_subscriber( + self, + subscriber, + subject_filter=None, + identifier=None, + ): + subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter) + return self._comm.add_broadcast_subscriber(subscriber, identifier) + + def add_task_subscriber(self, subscriber, identifier=None): + return self._comm.add_task_subscriber(subscriber, identifier) + + def remove_rpc_subscriber(self, identifier): + return self._comm.remove_rpc_subscriber(identifier) - communicator = rmq.RmqThreadCommunicator.connect( - connection_params={'url': 'amqp://guest:guest@localhost:5672/'}, - message_exchange=message_exchange, - task_exchange=task_exchange, - task_queue=task_queue, - ) - communicator._loop.set_debug(True) + def remove_broadcast_subscriber(self, identifier): + return self._comm.remove_broadcast_subscriber(identifier) - yield communicator + def remove_task_subscriber(self, identifier): + return self._comm.remove_task_subscriber(identifier) - communicator.close() + def rpc_send(self, recipient_id, msg): + return self._comm.rpc_send(recipient_id, msg) + + def broadcast_send( + self, + body, + sender=None, + subject=None, + correlation_id=None, + ): + return self._comm.broadcast_send(body, sender, subject, correlation_id) + + def task_send(self, task, no_reply=False): + return self._comm.task_send(task, no_reply) + + def close(self): + self._comm.close() + + +@pytest.fixture +def thread_coordinator(): + coordinator = ThreadCoordinator() + yield coordinator + coordinator.close() @pytest.fixture -def async_controller(thread_communicator: rmq.RmqThreadCommunicator): - yield process_comms.RemoteProcessController(thread_communicator) +def async_controller(thread_coordinator): + yield process_comms.RemoteProcessController(thread_coordinator) @pytest.fixture -def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): - yield process_comms.RemoteProcessThreadController(thread_communicator) +def sync_controller(thread_coordinator): + yield process_comms.RemoteProcessThreadController(thread_coordinator) class TestRemoteProcessController: @pytest.mark.asyncio - async def test_pause(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_pause(self, thread_coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=thread_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) # Send a pause message @@ -56,8 +101,8 @@ async def test_pause(self, thread_communicator, async_controller): assert proc.paused @pytest.mark.asyncio - async def test_play(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_play(self, thread_coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=thread_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) assert proc.pause() @@ -74,8 +119,8 @@ async def test_play(self, thread_communicator, async_controller): await async_controller.kill_process(proc.pid) @pytest.mark.asyncio - async def test_kill(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_kill(self, thread_coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=thread_coordinator) # Run the process in the event loop asyncio.ensure_future(proc.step_until_terminated()) @@ -87,8 +132,8 @@ async def test_kill(self, thread_communicator, async_controller): assert proc.state == plumpy.ProcessState.KILLED @pytest.mark.asyncio - async def test_status(self, thread_communicator, async_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_status(self, thread_coordinator, async_controller): + proc = utils.WaitForSignalProcess(coordinator=thread_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated()) @@ -100,15 +145,15 @@ async def test_status(self, thread_communicator, async_controller): # make sure proc reach the final state await async_controller.kill_process(proc.pid) - def test_broadcast(self, thread_communicator): + def test_broadcast(self, thread_coordinator): messages = [] def on_broadcast_receive(**msg): messages.append(msg) - thread_communicator.add_broadcast_subscriber(on_broadcast_receive) + thread_coordinator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(coordinator=thread_communicator) + proc = utils.DummyProcess(coordinator=thread_coordinator) proc.execute() expected_subjects = [] @@ -122,8 +167,8 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: @pytest.mark.asyncio - async def test_pause(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_pause(self, thread_coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=thread_coordinator) # Send a pause message pause_future = sync_controller.pause_process(proc.pid) @@ -136,22 +181,22 @@ async def test_pause(self, thread_communicator, sync_controller): assert proc.paused @pytest.mark.asyncio - async def test_pause_all(self, thread_communicator, sync_controller): + async def test_pause_all(self, thread_coordinator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=thread_coordinator)) sync_controller.pause_all("Slow yo' roll") # Wait until they are all paused await utils.wait_util(lambda: all([proc.paused for proc in procs])) @pytest.mark.asyncio - async def test_play_all(self, thread_communicator, sync_controller): + async def test_play_all(self, thread_coordinator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + proc = utils.WaitForSignalProcess(coordinator=thread_coordinator) procs.append(proc) proc.pause('hold tight') @@ -161,8 +206,8 @@ async def test_play_all(self, thread_communicator, sync_controller): await utils.wait_util(lambda: all([not proc.paused for proc in procs])) @pytest.mark.asyncio - async def test_play(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_play(self, thread_coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=thread_coordinator) assert proc.pause() # Send a play message @@ -175,8 +220,8 @@ async def test_play(self, thread_communicator, sync_controller): assert proc.state == plumpy.ProcessState.CREATED @pytest.mark.asyncio - async def test_kill(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_kill(self, thread_coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=thread_coordinator) # Send a kill message kill_future = sync_controller.kill_process(proc.pid) @@ -189,11 +234,11 @@ async def test_kill(self, thread_communicator, sync_controller): assert proc.state == plumpy.ProcessState.KILLED @pytest.mark.asyncio - async def test_kill_all(self, thread_communicator, sync_controller): + async def test_kill_all(self, thread_coordinator, sync_controller): """Test pausing all processes on a communicator""" procs = [] for _ in range(10): - procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator)) + procs.append(utils.WaitForSignalProcess(coordinator=thread_coordinator)) msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down') @@ -202,8 +247,8 @@ async def test_kill_all(self, thread_communicator, sync_controller): assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs]) @pytest.mark.asyncio - async def test_status(self, thread_communicator, sync_controller): - proc = utils.WaitForSignalProcess(coordinator=thread_communicator) + async def test_status(self, thread_coordinator, sync_controller): + proc = utils.WaitForSignalProcess(coordinator=thread_coordinator) # Run the process in the background asyncio.ensure_future(proc.step_until_terminated())