From f5e303657320c1d3a7ccbd40661cf476437a0ada Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Tue, 17 Dec 2024 22:20:08 +0100 Subject: [PATCH] Mock coordinator for process tests --- tests/test_processes.py | 14 ++--- tests/utils.py | 128 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 9 deletions(-) diff --git a/tests/test_processes.py b/tests/test_processes.py index 927d42b3..52a456dc 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -5,16 +5,14 @@ import enum import unittest -import kiwipy import pytest from plumpy.futures import CancellableAction -from tests import utils import plumpy from plumpy import BundleKeys, Process, ProcessState from plumpy.message import MessageBuilder from plumpy.utils import AttributesFrozendict -from tests import utils +from . import utils class ForgetToCallParent(plumpy.Process): @@ -1066,16 +1064,15 @@ def test_paused(self): self.assertSetEqual(events_tester.called, events_tester.expected_events) def test_broadcast(self): - # FIXME: here I need a mock test - communicator = kiwipy.LocalCommunicator() + coordinator = utils.MockCoordinator() messages = [] def on_broadcast_receive(_comm, body, sender, subject, correlation_id): messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id}) - communicator.add_broadcast_subscriber(on_broadcast_receive) - proc = utils.DummyProcess(coordinator=communicator) + coordinator.add_broadcast_subscriber(on_broadcast_receive) + proc = utils.DummyProcess(coordinator=coordinator) proc.execute() expected_subjects = [] @@ -1083,8 +1080,7 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id): from_state = utils.DummyProcess.EXPECTED_STATE_SEQUENCE[i - 1].value if i != 0 else None expected_subjects.append(f'state_changed.{from_state}.{state.value}') - for i, message in enumerate(messages): - self.assertEqual(message['subject'], expected_subjects[i]) + assert [msg['subject'] for msg in messages] == expected_subjects class _RestartProcess(utils.WaitForSignalProcess): diff --git a/tests/utils.py b/tests/utils.py index 123d6e72..d67383fa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,16 +3,144 @@ import asyncio import collections +import sys +from typing import Any import unittest from collections.abc import Mapping +import concurrent.futures import plumpy from plumpy import persistence, process_states, processes, utils +from plumpy.exceptions import CoordinatorConnectionError from plumpy.message import MessageBuilder +from plumpy.rmq import TaskRejected +import shortuuid Snapshot = collections.namedtuple('Snapshot', ['state', 'bundle', 'outputs']) +class MockCoordinator: + def __init__(self): + self._task_subscribers = {} + self._broadcast_subscribers = {} + self._rpc_subscribers = {} + self._closed = False + + def is_closed(self) -> bool: + return self._closed + + def close(self): + if self._closed: + return + self._closed = True + del self._task_subscribers + del self._broadcast_subscribers + del self._rpc_subscribers + + def add_rpc_subscriber(self, subscriber, identifier=None) -> Any: + self._ensure_open() + identifier = identifier or shortuuid.uuid() + if identifier in self._rpc_subscribers: + raise RuntimeError(f"Duplicate RPC subscriber with identifier '{identifier}'") + self._rpc_subscribers[identifier] = subscriber + return identifier + + def remove_rpc_subscriber(self, identifier): + self._ensure_open() + try: + self._rpc_subscribers.pop(identifier) + except KeyError as exc: + raise ValueError(f"Unknown subscriber '{identifier}'") from exc + + def add_task_subscriber(self, subscriber, identifier=None): + """ + Register a task subscriber + + :param subscriber: The task callback function + :param identifier: the subscriber identifier + """ + self._ensure_open() + identifier = identifier or shortuuid.uuid() + if identifier in self._rpc_subscribers: + raise RuntimeError(f"Duplicate RPC subscriber with identifier '{identifier}'") + self._task_subscribers[identifier] = subscriber + return identifier + + def remove_task_subscriber(self, identifier): + """ + Remove a task subscriber + + :param identifier: the subscriber to remove + :raises: ValueError if identifier does not correspond to a known subscriber + """ + self._ensure_open() + try: + self._task_subscribers.pop(identifier) + except KeyError as exception: + raise ValueError(f"Unknown subscriber: '{identifier}'") from exception + + def add_broadcast_subscriber(self, subscriber, subject_filter=None, identifier=None) -> Any: + self._ensure_open() + identifier = identifier or shortuuid.uuid() + if identifier in self._broadcast_subscribers: + raise RuntimeError(f"Duplicate RPC subscriber with identifier '{identifier}'") + + self._broadcast_subscribers[identifier] = subscriber + return identifier + + def remove_broadcast_subscriber(self, identifier): + self._ensure_open() + try: + del self._broadcast_subscribers[identifier] + except KeyError as exception: + raise ValueError(f"Broadcast subscriber '{identifier}' unknown") from exception + + def task_send(self, msg, no_reply=False): + self._ensure_open() + future = concurrent.futures.Future() + + for subscriber in self._task_subscribers.values(): + try: + result = subscriber(self, msg) + future.set_result(result) + break + except TaskRejected: + pass + except Exception: # pylint: disable=broad-except + future.set_exception(RuntimeError(sys.exc_info())) + break + + if no_reply: + return None + + return future + + def rpc_send(self, recipient_id, msg): + self._ensure_open() + try: + subscriber = self._rpc_subscribers[recipient_id] + except KeyError as exception: + raise RuntimeError(f"Unknown rpc recipient '{recipient_id}'") from exception + else: + future = concurrent.futures.Future() + try: + future.set_result(subscriber(self, msg)) + except Exception: # pylint: disable=broad-except + future.set_exception(RuntimeError(sys.exc_info())) + + return future + + def broadcast_send(self, body, sender=None, subject=None, correlation_id=None): + self._ensure_open() + for subscriber in self._broadcast_subscribers.values(): + subscriber(self, body=body, sender=sender, subject=subject, correlation_id=correlation_id) + return True + + def _ensure_open(self): + if self.is_closed(): + raise CoordinatorConnectionError + + class TestCase(unittest.TestCase): pass