Skip to content

Commit

Permalink
Showing how using interface can avoid making change in kiwipy
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 17, 2024
1 parent 5091a17 commit 18de42f
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 71 deletions.
118 changes: 92 additions & 26 deletions tests/rmq/test_communications.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# -*- coding: utf-8 -*-
"""Tests for the :mod:`plumpy.rmq.communications` module."""

from typing import override
import kiwipy
from kiwipy.communications import exceptions
import pytest
from kiwipy import CommunicatorHelper

from plumpy.rmq.communications import LoopCommunicator
import shortuuid


class Subscriber:
Expand All @@ -14,21 +17,84 @@ def __call__(self):
pass


class Communicator(CommunicatorHelper):
def task_send(self, task, no_reply=False):
pass
class CoordinatorWithLoopCommunicatorHelper:
def __init__(self):
class _Communicator(kiwipy.CommunicatorHelper):
def task_send(self, task, no_reply=False):
pass

def rpc_send(self, recipient_id, msg):
pass

def broadcast_send(self, body, sender=None, subject=None, correlation_id=None):
pass

@override
def add_broadcast_subscriber(self, subscriber, subject_filter, identifier=None):
"""Duplicate the add_broadcast_subscriber from CommunicatorHelper and add support for
passing `subject_filter`.
"""

self._ensure_open()
identifier = identifier or shortuuid.uuid()
subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter)
if identifier in self._broadcast_subscribers:
raise exceptions.DuplicateSubscriberIdentifier(f"Broadcast identifier '{identifier}'")

self._broadcast_subscribers[identifier] = subscriber
return identifier

self._comm = LoopCommunicator(_Communicator())

def add_rpc_subscriber(self, subscriber, identifier=None):
return self._comm.add_rpc_subscriber(subscriber, identifier)

def add_broadcast_subscriber(
self,
subscriber,
subject_filter=None,
identifier=None,
):
subscriber = kiwipy.BroadcastFilter(subscriber, subject=subject_filter)
return self._comm.add_broadcast_subscriber(subscriber, subject_filter, identifier)

def add_task_subscriber(self, subscriber, identifier=None):
return self._comm.add_task_subscriber(subscriber, identifier)

def remove_rpc_subscriber(self, identifier):
return self._comm.remove_rpc_subscriber(identifier)

def remove_broadcast_subscriber(self, identifier):
return self._comm.remove_broadcast_subscriber(identifier)

def remove_task_subscriber(self, identifier):
return self._comm.remove_task_subscriber(identifier)

def rpc_send(self, recipient_id, msg):
pass
return self._comm.rpc_send(recipient_id, msg)

def broadcast_send(self, body, sender=None, subject=None, correlation_id=None):
pass
def broadcast_send(
self,
body,
sender=None,
subject=None,
correlation_id=None,
):
return self._comm.broadcast_send(body, sender, subject, correlation_id)

def task_send(self, task, no_reply=False):
return self._comm.task_send(task, no_reply)

def close(self):
self._comm.close()


@pytest.fixture
def loop_communicator():
def _communicator():
"""Return an instance of `LoopCommunicator`."""
return LoopCommunicator(Communicator())
coordinator = CoordinatorWithLoopCommunicatorHelper()
yield coordinator
coordinator.close()


@pytest.fixture
Expand All @@ -37,40 +103,40 @@ def subscriber():
return Subscriber()


def test_add_rpc_subscriber(loop_communicator, subscriber):
def test_add_rpc_subscriber(_communicator, subscriber):
"""Test the `LoopCommunicator.add_rpc_subscriber` method."""
assert loop_communicator.add_rpc_subscriber(subscriber) is not None
assert _communicator.add_rpc_subscriber(subscriber) is not None

identifier = 'identifier'
assert loop_communicator.add_rpc_subscriber(subscriber, identifier) == identifier
assert _communicator.add_rpc_subscriber(subscriber, identifier) == identifier


def test_remove_rpc_subscriber(loop_communicator, subscriber):
def test_remove_rpc_subscriber(_communicator, subscriber):
"""Test the `LoopCommunicator.remove_rpc_subscriber` method."""
identifier = loop_communicator.add_rpc_subscriber(subscriber)
loop_communicator.remove_rpc_subscriber(identifier)
identifier = _communicator.add_rpc_subscriber(subscriber)
_communicator.remove_rpc_subscriber(identifier)


def test_add_broadcast_subscriber(loop_communicator, subscriber):
def test_add_broadcast_subscriber(_communicator, subscriber):
"""Test the `LoopCommunicator.add_broadcast_subscriber` method."""
assert loop_communicator.add_broadcast_subscriber(subscriber) is not None
assert _communicator.add_broadcast_subscriber(subscriber) is not None

identifier = 'identifier'
assert loop_communicator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier
assert _communicator.add_broadcast_subscriber(subscriber, identifier=identifier) == identifier


def test_remove_broadcast_subscriber(loop_communicator, subscriber):
def test_remove_broadcast_subscriber(_communicator, subscriber):
"""Test the `LoopCommunicator.remove_broadcast_subscriber` method."""
identifier = loop_communicator.add_broadcast_subscriber(subscriber)
loop_communicator.remove_broadcast_subscriber(identifier)
identifier = _communicator.add_broadcast_subscriber(subscriber)
_communicator.remove_broadcast_subscriber(identifier)


def test_add_task_subscriber(loop_communicator, subscriber):
def test_add_task_subscriber(_communicator, subscriber):
"""Test the `LoopCommunicator.add_task_subscriber` method."""
assert loop_communicator.add_task_subscriber(subscriber) is not None
assert _communicator.add_task_subscriber(subscriber) is not None


def test_remove_task_subscriber(loop_communicator, subscriber):
def test_remove_task_subscriber(_communicator, subscriber):
"""Test the `LoopCommunicator.remove_task_subscriber` method."""
identifier = loop_communicator.add_task_subscriber(subscriber)
loop_communicator.remove_task_subscriber(identifier)
identifier = _communicator.add_task_subscriber(subscriber)
_communicator.remove_task_subscriber(identifier)
2 changes: 1 addition & 1 deletion tests/rmq/test_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pytest
import shortuuid
import yaml
from kiwipy import BroadcastFilter, rmq
from kiwipy import rmq

import plumpy
from plumpy.rmq import communications, process_comms
Expand Down
Loading

0 comments on commit 18de42f

Please sign in to comment.