Skip to content

Commit

Permalink
Showing how using interface can avoid making change in kiwipy
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 17, 2024
1 parent 5091a17 commit 2f1dab8
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 45 deletions.
2 changes: 1 addition & 1 deletion tests/rmq/test_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest
import shortuuid
import yaml
from kiwipy import BroadcastFilter, rmq
from kiwipy import rmq

import plumpy
from plumpy.rmq import communications, process_comms
Expand Down
133 changes: 89 additions & 44 deletions tests/rmq/test_process_comms.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,96 @@
# -*- coding: utf-8 -*-
import asyncio
from typing import Any

import kiwipy
import pytest
import shortuuid
from kiwipy import rmq

import plumpy
from plumpy.message import KILL_MSG, MESSAGE_KEY
from plumpy.rmq import process_comms

from .. import utils


@pytest.fixture
def thread_communicator():
message_exchange = f'{__file__}.{shortuuid.uuid()}'
task_exchange = f'{__file__}.{shortuuid.uuid()}'
task_queue = f'{__file__}.{shortuuid.uuid()}'
class ThreadCoordinator:
def __init__(self):
message_exchange = f'{__file__}.{shortuuid.uuid()}'
task_exchange = f'{__file__}.{shortuuid.uuid()}'
task_queue = f'{__file__}.{shortuuid.uuid()}'

self._comm = rmq.RmqThreadCommunicator.connect(
connection_params={'url': 'amqp://guest:guest@localhost:5672/'},
message_exchange=message_exchange,
task_exchange=task_exchange,
task_queue=task_queue,
)
self._comm._loop.set_debug(True)

def add_rpc_subscriber(self, subscriber, identifier=None) -> Any:
return self._comm.add_rpc_subscriber(subscriber, identifier)

def add_broadcast_subscriber(
self,
subscriber,
subject_filter=None,
identifier=None,
):
subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter)
return self._comm.add_broadcast_subscriber(subscriber, identifier)

def add_task_subscriber(self, subscriber, identifier=None):
return self._comm.add_task_subscriber(subscriber, identifier)

def remove_rpc_subscriber(self, identifier):
return self._comm.remove_rpc_subscriber(identifier)

communicator = rmq.RmqThreadCommunicator.connect(
connection_params={'url': 'amqp://guest:guest@localhost:5672/'},
message_exchange=message_exchange,
task_exchange=task_exchange,
task_queue=task_queue,
)
communicator._loop.set_debug(True)
def remove_broadcast_subscriber(self, identifier):
return self._comm.remove_broadcast_subscriber(identifier)

yield communicator
def remove_task_subscriber(self, identifier):
return self._comm.remove_task_subscriber(identifier)

communicator.close()
def rpc_send(self, recipient_id, msg):
return self._comm.rpc_send(recipient_id, msg)

def broadcast_send(
self,
body,
sender=None,
subject=None,
correlation_id=None,
):
return self._comm.broadcast_send(body, sender, subject, correlation_id)

def task_send(self, task, no_reply=False):
return self._comm.task_send(task, no_reply)

def close(self):
self._comm.close()


@pytest.fixture
def thread_coordinator():
coordinator = ThreadCoordinator()
yield coordinator
coordinator.close()


@pytest.fixture
def async_controller(thread_communicator: rmq.RmqThreadCommunicator):
yield process_comms.RemoteProcessController(thread_communicator)
def async_controller(thread_coordinator):
yield process_comms.RemoteProcessController(thread_coordinator)


@pytest.fixture
def sync_controller(thread_communicator: rmq.RmqThreadCommunicator):
yield process_comms.RemoteProcessThreadController(thread_communicator)
def sync_controller(thread_coordinator):
yield process_comms.RemoteProcessThreadController(thread_coordinator)


class TestRemoteProcessController:
@pytest.mark.asyncio
async def test_pause(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
async def test_pause(self, thread_coordinator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_coordinator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())
# Send a pause message
Expand All @@ -56,8 +101,8 @@ async def test_pause(self, thread_communicator, async_controller):
assert proc.paused

@pytest.mark.asyncio
async def test_play(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
async def test_play(self, thread_coordinator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_coordinator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())
assert proc.pause()
Expand All @@ -74,8 +119,8 @@ async def test_play(self, thread_communicator, async_controller):
await async_controller.kill_process(proc.pid)

@pytest.mark.asyncio
async def test_kill(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
async def test_kill(self, thread_coordinator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_coordinator)
# Run the process in the event loop
asyncio.ensure_future(proc.step_until_terminated())

Expand All @@ -87,8 +132,8 @@ async def test_kill(self, thread_communicator, async_controller):
assert proc.state == plumpy.ProcessState.KILLED

@pytest.mark.asyncio
async def test_status(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
async def test_status(self, thread_coordinator, async_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_coordinator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())

Expand All @@ -100,15 +145,15 @@ async def test_status(self, thread_communicator, async_controller):
# make sure proc reach the final state
await async_controller.kill_process(proc.pid)

def test_broadcast(self, thread_communicator):
def test_broadcast(self, thread_coordinator):
messages = []

def on_broadcast_receive(**msg):
messages.append(msg)

thread_communicator.add_broadcast_subscriber(on_broadcast_receive)
thread_coordinator.add_broadcast_subscriber(on_broadcast_receive)

proc = utils.DummyProcess(coordinator=thread_communicator)
proc = utils.DummyProcess(coordinator=thread_coordinator)
proc.execute()

expected_subjects = []
Expand All @@ -122,8 +167,8 @@ def on_broadcast_receive(**msg):

class TestRemoteProcessThreadController:
@pytest.mark.asyncio
async def test_pause(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
async def test_pause(self, thread_coordinator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_coordinator)

# Send a pause message
pause_future = sync_controller.pause_process(proc.pid)
Expand All @@ -136,22 +181,22 @@ async def test_pause(self, thread_communicator, sync_controller):
assert proc.paused

@pytest.mark.asyncio
async def test_pause_all(self, thread_communicator, sync_controller):
async def test_pause_all(self, thread_coordinator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
for _ in range(10):
procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator))
procs.append(utils.WaitForSignalProcess(coordinator=thread_coordinator))

sync_controller.pause_all("Slow yo' roll")
# Wait until they are all paused
await utils.wait_util(lambda: all([proc.paused for proc in procs]))

@pytest.mark.asyncio
async def test_play_all(self, thread_communicator, sync_controller):
async def test_play_all(self, thread_coordinator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
for _ in range(10):
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_coordinator)
procs.append(proc)
proc.pause('hold tight')

Expand All @@ -161,8 +206,8 @@ async def test_play_all(self, thread_communicator, sync_controller):
await utils.wait_util(lambda: all([not proc.paused for proc in procs]))

@pytest.mark.asyncio
async def test_play(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
async def test_play(self, thread_coordinator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_coordinator)
assert proc.pause()

# Send a play message
Expand All @@ -175,8 +220,8 @@ async def test_play(self, thread_communicator, sync_controller):
assert proc.state == plumpy.ProcessState.CREATED

@pytest.mark.asyncio
async def test_kill(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
async def test_kill(self, thread_coordinator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_coordinator)

# Send a kill message
kill_future = sync_controller.kill_process(proc.pid)
Expand All @@ -189,11 +234,11 @@ async def test_kill(self, thread_communicator, sync_controller):
assert proc.state == plumpy.ProcessState.KILLED

@pytest.mark.asyncio
async def test_kill_all(self, thread_communicator, sync_controller):
async def test_kill_all(self, thread_coordinator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
for _ in range(10):
procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator))
procs.append(utils.WaitForSignalProcess(coordinator=thread_coordinator))

msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down')

Expand All @@ -202,8 +247,8 @@ async def test_kill_all(self, thread_communicator, sync_controller):
assert all([proc.state == plumpy.ProcessState.KILLED for proc in procs])

@pytest.mark.asyncio
async def test_status(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
async def test_status(self, thread_coordinator, sync_controller):
proc = utils.WaitForSignalProcess(coordinator=thread_coordinator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())

Expand Down

0 comments on commit 2f1dab8

Please sign in to comment.