Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plumpy.ProcessListener made persistent support/0.21.x #277

Merged
merged 4 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/plumpy/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> None:
wrapped(self, *args, **kwargs)
self._called -= 1

# Forward wrapped function name to the decorator to show the correct name in the ``call_with_super_check``
wrapper.__name__ = wrapped.__name__
return wrapper


Expand Down
53 changes: 53 additions & 0 deletions src/plumpy/event_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
import logging
from typing import TYPE_CHECKING, Any, Callable

from . import persistence

if TYPE_CHECKING:
sphuber marked this conversation as resolved.
Show resolved Hide resolved
from typing import Set, Type
from .process_listener import ProcessListener # pylint: disable=cyclic-import

Check warning on line 9 in src/plumpy/event_helper.py

View check run for this annotation

Codecov / codecov/patch

src/plumpy/event_helper.py#L8-L9

Added lines #L8 - L9 were not covered by tests
sphuber marked this conversation as resolved.
Show resolved Hide resolved

_LOGGER = logging.getLogger(__name__)


@persistence.auto_persist('_listeners', '_listener_type')
class EventHelper(persistence.Savable):

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

Check warning on line 28 in src/plumpy/event_helper.py

View check run for this annotation

Codecov / codecov/patch

src/plumpy/event_helper.py#L28

Added line #L28 was not covered by tests

def remove_all_listeners(self) -> None:
self._listeners.clear()

Check warning on line 31 in src/plumpy/event_helper.py

View check run for this annotation

Codecov / codecov/patch

src/plumpy/event_helper.py#L31

Added line #L31 was not covered by tests

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.

:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method

"""
if event_function is None:
raise ValueError('Must provide valid event method')

Check warning on line 46 in src/plumpy/event_helper.py

View check run for this annotation

Codecov / codecov/patch

src/plumpy/event_helper.py#L46

Added line #L46 was not covered by tests

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)

Check warning on line 53 in src/plumpy/event_helper.py

View check run for this annotation

Codecov / codecov/patch

src/plumpy/event_helper.py#L52-L53

Added lines #L52 - L53 were not covered by tests
26 changes: 24 additions & 2 deletions src/plumpy/process_listener.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
# -*- coding: utf-8 -*-
import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Dict, Optional

from . import persistence
from .utils import SAVED_STATE_TYPE, protected

__all__ = ['ProcessListener']

if TYPE_CHECKING:
from .processes import Process # pylint: disable=cyclic-import


class ProcessListener(metaclass=abc.ABCMeta):
@persistence.auto_persist('_params')
class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta):

# region Persistence methods

def __init__(self) -> None:
super().__init__()
self._params: Dict[str, Any] = {}

def init(self, **kwargs: Any) -> None:
self._params = kwargs

Check warning on line 24 in src/plumpy/process_listener.py

View check run for this annotation

Codecov / codecov/patch

src/plumpy/process_listener.py#L24

Added line #L24 was not covered by tests

@protected
def load_instance_state(
self, saved_state: SAVED_STATE_TYPE, load_context: Optional[persistence.LoadSaveContext]
) -> None:
super().load_instance_state(saved_state, load_context)
self.init(**saved_state['_params'])

Check warning on line 31 in src/plumpy/process_listener.py

View check run for this annotation

Codecov / codecov/patch

src/plumpy/process_listener.py#L30-L31

Added lines #L30 - L31 were not covered by tests

# endregion

def on_process_created(self, process: 'Process') -> None:
"""
Expand Down
17 changes: 10 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .base import state_machine
from .base.state_machine import StateEntryFailed, StateMachine, TransitionFailed, event
from .base.utils import call_with_super_check, super_check
from .event_helper import EventHelper
from .process_listener import ProcessListener
from .process_spec import ProcessSpec
from .utils import PID_TYPE, SAVED_STATE_TYPE, protected
Expand Down Expand Up @@ -91,7 +92,9 @@
return func_wrapper


@persistence.auto_persist('_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status')
@persistence.auto_persist(
'_pid', '_creation_time', '_future', '_paused', '_status', '_pre_paused_status', '_event_helper'
)
class Process(StateMachine, persistence.Savable, metaclass=ProcessStateMachineMeta):
"""
The Process class is the base for any unit of work in plumpy.
Expand Down Expand Up @@ -289,7 +292,7 @@

# Runtime variables
self._future = persistence.SavableFuture(loop=self._loop)
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = logger
self._communicator = communicator

Expand Down Expand Up @@ -612,7 +615,7 @@

# Runtime variables, set initial states
self._future = persistence.SavableFuture()
self.__event_helper = utils.EventHelper(ProcessListener)
self._event_helper = EventHelper(ProcessListener)
self._logger = None
self._communicator = None

Expand Down Expand Up @@ -661,11 +664,11 @@

"""
assert (listener != self), 'Cannot listen to yourself!'
self.__event_helper.add_listener(listener)
self._event_helper.add_listener(listener)

def remove_process_listener(self, listener: ProcessListener) -> None:
"""Remove a process listener from the process."""
self.__event_helper.remove_listener(listener)
self._event_helper.remove_listener(listener)

Check warning on line 671 in src/plumpy/processes.py

View check run for this annotation

Codecov / codecov/patch

src/plumpy/processes.py#L671

Added line #L671 was not covered by tests

@protected
def set_logger(self, logger: logging.Logger) -> None:
Expand Down Expand Up @@ -778,7 +781,7 @@
"""Output is about to be emitted."""

def on_output_emitted(self, output_port: str, value: Any, dynamic: bool) -> None:
self.__event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)
self._event_helper.fire_event(ProcessListener.on_output_emitted, self, output_port, value, dynamic)

@super_check
def on_wait(self, awaitables: Sequence[Awaitable]) -> None:
Expand Down Expand Up @@ -891,7 +894,7 @@
self._closed = True

def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
self.__event_helper.fire_event(evt, self, *args, **kwargs)
self._event_helper.fire_event(evt, self, *args, **kwargs)

# endregion

Expand Down
41 changes: 0 additions & 41 deletions src/plumpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,6 @@
PID_TYPE = Hashable # pylint: disable=invalid-name


class EventHelper:

def __init__(self, listener_type: 'Type[ProcessListener]'):
assert listener_type is not None, 'Must provide valid listener type'

self._listener_type = listener_type
self._listeners: 'Set[ProcessListener]' = set()

def add_listener(self, listener: 'ProcessListener') -> None:
assert isinstance(listener, self._listener_type), 'Listener is not of right type'
self._listeners.add(listener)

def remove_listener(self, listener: 'ProcessListener') -> None:
self._listeners.discard(listener)

def remove_all_listeners(self) -> None:
self._listeners.clear()

@property
def listeners(self) -> 'Set[ProcessListener]':
return self._listeners

def fire_event(self, event_function: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
"""Call an event method on all listeners.

:param event_function: the method of the ProcessListener
:param args: arguments to pass to the method
:param kwargs: keyword arguments to pass to the method

"""
if event_function is None:
raise ValueError('Must provide valid event method')

# Make a copy of the list for iteration just in case it changes in a callback
for listener in list(self.listeners):
try:
getattr(listener, event_function.__name__)(*args, **kwargs)
except Exception as exception: # pylint: disable=broad-except
_LOGGER.error("Listener '%s' produced an exception:\n%s", listener, exception)


class Frozendict(Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`
Expand Down
5 changes: 3 additions & 2 deletions test/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,8 @@ def test_instance_state_with_outputs(self):
# Check that it is a copy
self.assertIsNot(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
# Check the contents are the same
self.assertDictEqual(outputs, bundle.get(BundleKeys.OUTPUTS, {}))
# Remove the ``ProcessSaver`` instance that is only used for testing
utils.compare_dictionaries(None, None, outputs, bundle.get(BundleKeys.OUTPUTS, {}), exclude={'_listeners'})

self.assertIsNot(proc.outputs, saver.snapshots[-1].get(BundleKeys.OUTPUTS, {}))

Expand Down Expand Up @@ -875,7 +876,7 @@ def _check_round_trip(self, proc1):
bundle2 = plumpy.Bundle(proc2)

self.assertEqual(proc1.pid, proc2.pid)
self.assertDictEqual(bundle1, bundle2)
utils.compare_dictionaries(None, None, bundle1, bundle2, exclude={'_listeners'})


class TestProcessNamespace(unittest.TestCase):
Expand Down
40 changes: 40 additions & 0 deletions test/test_workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,46 @@ def test_checkpointing(self):
if step not in ['isA', 's2', 'isB', 's3']:
self.assertTrue(finished, f'Step {step} was not called by workflow')

def test_listener_persistence(self):
persister = plumpy.InMemoryPersister()
process_finished_count = 0

class TestListener(plumpy.ProcessListener):

def on_process_finished(self, process, output):
nonlocal process_finished_count
process_finished_count += 1

class SimpleWorkChain(plumpy.WorkChain):

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(
cls.step1,
cls.step2,
)

def step1(self):
persister.save_checkpoint(self, 'step1')

def step2(self):
persister.save_checkpoint(self, 'step2')

# add SimpleWorkChain and TestListener to this module global namespace, so they can be reloaded from checkpoint
globals()['SimpleWorkChain'] = SimpleWorkChain
globals()['TestListener'] = TestListener

workchain = SimpleWorkChain()
workchain.add_process_listener(TestListener())
output = workchain.execute()

self.assertEqual(process_finished_count, 1)

workchain_checkpoint = persister.load_checkpoint(workchain.pid, 'step1').unbundle()
workchain_checkpoint.execute()
self.assertEqual(process_finished_count, 2)

def test_return_in_outline(self):

class WcWithReturn(WorkChain):
Expand Down
Loading
Loading