From c2e62981ae567d2d5401160440b9b30225cfbaf0 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 27 Nov 2024 14:08:14 +0100 Subject: [PATCH] Use ruff and use aiida-core ruff config --- .github/workflows/validate_release_tag.py | 8 ++- .pre-commit-config.yaml | 59 ++++++++----------- docs/source/conf.py | 32 ++++++---- examples/process_helloworld.py | 1 - examples/process_launch.py | 1 - examples/process_wait_and_resume.py | 10 +--- examples/workchain_simple.py | 1 - pyproject.toml | 35 +++++++++-- src/plumpy/__init__.py | 18 ++++-- src/plumpy/base/__init__.py | 2 +- src/plumpy/base/state_machine.py | 28 ++++----- src/plumpy/communications.py | 17 ++++-- src/plumpy/event_helper.py | 1 - src/plumpy/events.py | 10 +++- src/plumpy/futures.py | 3 +- src/plumpy/lang.py | 3 +- src/plumpy/mixins.py | 1 + src/plumpy/persistence.py | 29 ++++----- src/plumpy/ports.py | 37 +++++------- src/plumpy/process_comms.py | 42 +++++--------- src/plumpy/process_listener.py | 1 - src/plumpy/process_spec.py | 7 ++- src/plumpy/process_states.py | 25 ++++---- src/plumpy/processes.py | 65 ++++++++++++--------- src/plumpy/utils.py | 21 +++++-- src/plumpy/workchains.py | 27 +++------ test/base/test_statemachine.py | 8 +-- test/base/test_utils.py | 6 -- test/conftest.py | 1 + test/persistence/test_inmemory.py | 19 +++--- test/persistence/test_pickle.py | 20 +++---- test/rmq/test_communicator.py | 34 ++++------- test/rmq/test_process_comms.py | 9 +-- test/test_communications.py | 5 +- test/test_events.py | 2 +- test/test_expose.py | 32 +++------- test/test_lang.py | 13 ----- test/test_loaders.py | 8 ++- test/test_persistence.py | 9 +-- test/test_port.py | 7 +-- test/test_process_comms.py | 14 ++--- test/test_process_spec.py | 1 - test/test_processes.py | 71 +++++------------------ test/test_utils.py | 7 --- test/test_waiting_process.py | 1 - test/test_workchains.py | 63 +++++++++----------- test/utils.py | 45 ++++++-------- 47 files changed, 376 insertions(+), 483 deletions(-) diff --git a/.github/workflows/validate_release_tag.py b/.github/workflows/validate_release_tag.py index bdd35537..4caf68b8 100644 --- a/.github/workflows/validate_release_tag.py +++ b/.github/workflows/validate_release_tag.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Validate that the version in the tag label matches the version of the package.""" + import argparse import ast from pathlib import Path @@ -17,8 +18,11 @@ def get_version_from_module(content: str) -> str: try: return next( - ast.literal_eval(statement.value) for statement in module.body if isinstance(statement, ast.Assign) - for target in statement.targets if isinstance(target, ast.Name) and target.id == '__version__' + ast.literal_eval(statement.value) + for statement in module.body + if isinstance(statement, ast.Assign) + for target in statement.targets + if isinstance(target, ast.Name) and target.id == '__version__' ) except StopIteration as exception: raise IOError('Unable to find the `__version__` attribute in the module.') from exception diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cae9888f..d90a06a1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,46 +1,35 @@ repos: -- repo: https://github.com/pre-commit/pre-commit-hooks + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: - - id: double-quote-string-fixer - - id: end-of-file-fixer - - id: fix-encoding-pragma - - id: mixed-line-ending - - id: trailing-whitespace + - id: double-quote-string-fixer + - id: end-of-file-fixer + - id: fix-encoding-pragma + - id: mixed-line-ending + - id: trailing-whitespace -- repo: https://github.com/ikamensh/flynt/ - rev: '0.77' + - repo: https://github.com/ikamensh/flynt/ + rev: 1.0.1 hooks: - - id: flynt + - id: flynt + args: [--line-length=120, --fail-on-change] -- repo: https://github.com/pycqa/isort - rev: '5.12.0' + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.0 hooks: - - id: isort + - id: ruff-format + exclude: &exclude_ruff > + (?x)^( -- repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.32.0 - hooks: - - id: yapf - name: yapf - types: [python] - args: ['-i'] - additional_dependencies: ['toml'] + )$ -- repo: https://github.com/PyCQA/pylint - rev: v2.15.8 - hooks: - - id: pylint - language: system - exclude: > - (?x)^( - docs/source/conf.py| - test/.*| - )$ + - id: ruff + exclude: *exclude_ruff + args: [--fix, --exit-non-zero-on-fix, --show-fixes] -- repo: local + - repo: local hooks: - - id: mypy + - id: mypy name: mypy entry: mypy args: [--config-file=pyproject.toml] @@ -49,6 +38,6 @@ repos: require_serial: true pass_filenames: true files: >- - (?x)^( - src/.*py| - )$ + (?x)^( + src/.*py| + )$ diff --git a/docs/source/conf.py b/docs/source/conf.py index a1c6f26e..b1a2a019 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -8,11 +8,9 @@ import filecmp import os -from pathlib import Path import shutil -import subprocess -import sys import tempfile +from pathlib import Path import plumpy @@ -32,8 +30,12 @@ master_doc = 'index' language = None extensions = [ - 'myst_nb', 'sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx', - 'IPython.sphinxext.ipython_console_highlighting' + 'myst_nb', + 'sphinx.ext.autodoc', + 'sphinx.ext.doctest', + 'sphinx.ext.viewcode', + 'sphinx.ext.intersphinx', + 'IPython.sphinxext.ipython_console_highlighting', ] # List of patterns, relative to source directory, that match files and @@ -46,14 +48,14 @@ intersphinx_mapping = { 'python': ('https://docs.python.org/3.8', None), - 'kiwipy': ('https://kiwipy.readthedocs.io/en/latest/', None) + 'kiwipy': ('https://kiwipy.readthedocs.io/en/latest/', None), } myst_enable_extensions = ['colon_fence', 'deflist', 'html_image', 'smartquotes', 'substitution'] myst_url_schemes = ('http', 'https', 'mailto') myst_substitutions = { 'rabbitmq': '[RabbitMQ](https://www.rabbitmq.com/)', - 'kiwipy': '[kiwipy](https://kiwipy.readthedocs.io)' + 'kiwipy': '[kiwipy](https://kiwipy.readthedocs.io)', } jupyter_execute_notebooks = 'cache' execution_show_tb = 'READTHEDOCS' in os.environ @@ -84,7 +86,7 @@ 'use_issues_button': True, 'path_to_docs': 'docs', 'use_edit_page_button': True, - 'extra_navbar': '' + 'extra_navbar': '', } # API Documentation @@ -112,9 +114,17 @@ def run_apidoc(app): # this ensures that document rebuilds are not triggered every time (due to change in file mtime) with tempfile.TemporaryDirectory() as tmpdirname: options = [ - '-o', tmpdirname, - str(package_dir), '--private', '--force', '--module-first', '--separate', '--no-toc', '--maxdepth', '4', - '-q' + '-o', + tmpdirname, + str(package_dir), + '--private', + '--force', + '--module-first', + '--separate', + '--no-toc', + '--maxdepth', + '4', + '-q', ] os.environ['SPHINX_APIDOC_OPTIONS'] = 'members,special-members,private-members,undoc-members,show-inheritance' diff --git a/examples/process_helloworld.py b/examples/process_helloworld.py index cf043eba..db2eff0f 100644 --- a/examples/process_helloworld.py +++ b/examples/process_helloworld.py @@ -3,7 +3,6 @@ class HelloWorld(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) diff --git a/examples/process_launch.py b/examples/process_launch.py index 645af0fd..3aa46fdc 100644 --- a/examples/process_launch.py +++ b/examples/process_launch.py @@ -4,7 +4,6 @@ import tempfile import kiwipy - import plumpy diff --git a/examples/process_wait_and_resume.py b/examples/process_wait_and_resume.py index 03e8b57a..f92fb2f7 100644 --- a/examples/process_wait_and_resume.py +++ b/examples/process_wait_and_resume.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- -from kiwipy import rmq - import plumpy +from kiwipy import rmq class WaitForResumeProc(plumpy.Process): - def run(self): print(f'Now I am running: {self.state}') return plumpy.Wait(self.after_resume_and_exec) @@ -15,12 +13,10 @@ def after_resume_and_exec(self): kwargs = { - 'connection_params': { - 'url': 'amqp://guest:guest@127.0.0.1:5672/' - }, + 'connection_params': {'url': 'amqp://guest:guest@127.0.0.1:5672/'}, 'message_exchange': 'WaitForResume.uuid-0', 'task_exchange': 'WaitForResume.uuid-0', - 'task_queue': 'WaitForResume.uuid-0' + 'task_queue': 'WaitForResume.uuid-0', } if __name__ == '__main__': diff --git a/examples/workchain_simple.py b/examples/workchain_simple.py index 078de3ca..aa189d3b 100644 --- a/examples/workchain_simple.py +++ b/examples/workchain_simple.py @@ -3,7 +3,6 @@ class AddAndMulWF(plumpy.WorkChain): - @classmethod def define(cls, spec): super().define(spec) diff --git a/pyproject.toml b/pyproject.toml index 2d38516d..2cfbe512 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,11 +80,36 @@ exclude = [ line-length = 120 fail-on-change = true -[tool.isort] -force_sort_within_sections = true -include_trailing_comma = true -line_length = 120 -multi_line_output = 3 +[tool.ruff] +line-length = 120 + +[tool.ruff.format] +quote-style = 'single' + +[tool.ruff.lint] +ignore = [ + 'F403', # Star imports unable to detect undefined names + 'F405', # Import may be undefined or defined from star imports + 'PLR0911', # Too many return statements + 'PLR0912', # Too many branches + 'PLR0913', # Too many arguments in function definition + 'PLR0915', # Too many statements + 'PLR2004', # Magic value used in comparison + 'RUF005', # Consider iterable unpacking instead of concatenation + 'RUF012' # Mutable class attributes should be annotated with `typing.ClassVar` +] +select = [ + 'E', # pydocstyle + 'W', # pydocstyle + 'F', # pyflakes + 'I', # isort + 'N', # pep8-naming + 'PLC', # pylint-convention + 'PLE', # pylint-error + 'PLR', # pylint-refactor + 'PLW', # pylint-warning + 'RUF' # ruff +] [tool.mypy] show_error_codes = true diff --git a/src/plumpy/__init__.py b/src/plumpy/__init__.py index ea88f872..64a304c9 100644 --- a/src/plumpy/__init__.py +++ b/src/plumpy/__init__.py @@ -21,9 +21,20 @@ from .workchains import * __all__ = ( - events.__all__ + exceptions.__all__ + processes.__all__ + utils.__all__ + futures.__all__ + mixins.__all__ + - persistence.__all__ + communications.__all__ + process_comms.__all__ + process_listener.__all__ + - workchains.__all__ + loaders.__all__ + ports.__all__ + process_states.__all__ + events.__all__ + + exceptions.__all__ + + processes.__all__ + + utils.__all__ + + futures.__all__ + + mixins.__all__ + + persistence.__all__ + + communications.__all__ + + process_comms.__all__ + + process_listener.__all__ + + workchains.__all__ + + loaders.__all__ + + ports.__all__ + + process_states.__all__ ) @@ -32,7 +43,6 @@ # https://docs.python.org/3.1/library/logging.html#library-config # for more details class NullHandler(logging.Handler): - def emit(self, record: logging.LogRecord) -> None: pass diff --git a/src/plumpy/base/__init__.py b/src/plumpy/base/__init__.py index 79450590..42b150ec 100644 --- a/src/plumpy/base/__init__.py +++ b/src/plumpy/base/__init__.py @@ -4,4 +4,4 @@ from .state_machine import * from .utils import * -__all__ = (state_machine.__all__ + utils.__all__) +__all__ = state_machine.__all__ + utils.__all__ diff --git a/src/plumpy/base/state_machine.py b/src/plumpy/base/state_machine.py index b62825e1..4e7a5722 100644 --- a/src/plumpy/base/state_machine.py +++ b/src/plumpy/base/state_machine.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """The state machine for processes""" + import enum import functools import inspect @@ -42,7 +43,6 @@ class InvalidStateError(Exception): class EventError(StateMachineError): - def __init__(self, evt: str, msg: str): super().__init__(msg) self.event = evt @@ -52,10 +52,7 @@ class TransitionFailed(Exception): """A state transition failed""" def __init__( - self, - initial_state: 'State', - final_state: Optional['State'] = None, - traceback_str: Optional[str] = None + self, initial_state: 'State', final_state: Optional['State'] = None, traceback_str: Optional[str] = None ) -> None: self.initial_state = initial_state self.final_state = final_state @@ -71,7 +68,7 @@ def _format_msg(self) -> str: def event( from_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', - to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*' + to_states: Union[str, Type['State'], Iterable[Type['State']]] = '*', ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """A decorator to check for correct transitions, raising ``EventError`` on invalid transitions.""" if from_states != '*': @@ -102,8 +99,8 @@ def transition(self: Any, *a: Any, **kw: Any) -> Any: raise EventError(evt_label, 'Machine did not transition') raise EventError( - evt_label, 'Event produced invalid state transition from ' - f'{initial.LABEL} to {self._state.LABEL}' + evt_label, + 'Event produced invalid state transition from ' f'{initial.LABEL} to {self._state.LABEL}', ) return result @@ -138,12 +135,12 @@ def __str__(self) -> str: @property def label(self) -> LABEL_TYPE: - """ Convenience property to get the state label """ + """Convenience property to get the state label""" return self.LABEL @super_check def enter(self) -> None: - """ Entering the state """ + """Entering the state""" def execute(self) -> Optional['State']: """ @@ -153,7 +150,7 @@ def execute(self) -> Optional['State']: @super_check def exit(self) -> None: - """ Exiting the state """ + """Exiting the state""" if self.is_terminal(): raise InvalidStateError(f'Cannot exit a terminal state {self.LABEL}') @@ -175,13 +172,13 @@ class StateEventHook(enum.Enum): procedure. The callback will be passed a state instance whose meaning will differ depending on the hook as commented below. """ + ENTERING_STATE: int = 0 # State passed will be the state that is being entered ENTERED_STATE: int = 1 # State passed will be the last state that we entered from EXITING_STATE: int = 2 # State passed will be the next state that will be entered (or None for terminal) class StateMachineMeta(type): - def __call__(cls, *args: Any, **kwargs: Any) -> 'StateMachine': """ Create the state machine and enter the initial state. @@ -301,11 +298,10 @@ def _fire_state_event(self, hook: Hashable, state: Optional[State]) -> None: @super_check def on_terminated(self) -> None: - """ Called when a terminal state is entered """ + """Called when a terminal state is entered""" def transition_to(self, new_state: Union[Hashable, State, Type[State]], *args: Any, **kwargs: Any) -> None: - assert not self._transitioning, \ - 'Cannot call transition_to when already transitioning state' + assert not self._transitioning, 'Cannot call transition_to when already transitioning state' initial_state_label = self._state.LABEL if self._state is not None else None label = None @@ -365,7 +361,7 @@ def create_state(self, state_label: Hashable, *args: Any, **kwargs: Any) -> Stat raise ValueError(f'{state_label} is not a valid state') def _exit_current_state(self, next_state: State) -> None: - """ Exit the given state """ + """Exit the given state""" # If we're just being constructed we may not have a state yet to exit, # in which case check the new state is the initial state diff --git a/src/plumpy/communications.py b/src/plumpy/communications.py index 51dff60d..f4941d0f 100644 --- a/src/plumpy/communications.py +++ b/src/plumpy/communications.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Module for general kiwipy communication methods""" + import asyncio import functools from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional @@ -10,7 +11,12 @@ from .utils import ensure_coroutine __all__ = [ - 'Communicator', 'RemoteException', 'DeliveryFailed', 'TaskRejected', 'plum_to_kiwi_future', 'wrap_communicator' + 'Communicator', + 'RemoteException', + 'DeliveryFailed', + 'TaskRejected', + 'plum_to_kiwi_future', + 'wrap_communicator', ] RemoteException = kiwipy.RemoteException @@ -55,8 +61,9 @@ def on_done(_plum_future: futures.Future) -> None: return kiwi_future -def convert_to_comm(callback: 'Subscriber', - loop: Optional[asyncio.AbstractEventLoop] = None) -> Callable[..., kiwipy.Future]: +def convert_to_comm( + callback: 'Subscriber', loop: Optional[asyncio.AbstractEventLoop] = None +) -> Callable[..., kiwipy.Future]: """ Take a callback function and converted it to one that will schedule a callback on the given even loop and return a kiwi future representing the future outcome @@ -67,7 +74,6 @@ def convert_to_comm(callback: 'Subscriber', :return: a new callback function that returns a future """ if isinstance(callback, kiwipy.BroadcastFilter): - # if the broadcast is filtered for this callback, # we don't want to go through the (costly) process # of setting up async tasks and callbacks @@ -84,7 +90,6 @@ def _passthrough(*args: Any, **kwargs: Any) -> bool: # pylint: disable=unused-a coro = ensure_coroutine(callback) def converted(communicator: kiwipy.Communicator, *args: Any, **kwargs: Any) -> kiwipy.Future: - if _passthrough(*args, **kwargs): kiwi_future = kiwipy.Future() kiwi_future.set_result(None) @@ -170,7 +175,7 @@ def broadcast_send( body: Optional[Any], sender: Optional[str] = None, subject: Optional[str] = None, - correlation_id: Optional['ID_TYPE'] = None + correlation_id: Optional['ID_TYPE'] = None, ) -> futures.Future: return self._communicator.broadcast_send(body, sender, subject, correlation_id) diff --git a/src/plumpy/event_helper.py b/src/plumpy/event_helper.py index 3a342321..2ff73597 100644 --- a/src/plumpy/event_helper.py +++ b/src/plumpy/event_helper.py @@ -14,7 +14,6 @@ @persistence.auto_persist('_listeners', '_listener_type') class EventHelper(persistence.Savable): - def __init__(self, listener_type: 'Type[ProcessListener]'): assert listener_type is not None, 'Must provide valid listener type' diff --git a/src/plumpy/events.py b/src/plumpy/events.py index 60a5306e..79bf3440 100644 --- a/src/plumpy/events.py +++ b/src/plumpy/events.py @@ -1,12 +1,18 @@ # -*- coding: utf-8 -*- """Event and loop related classes and functions""" + import asyncio import sys from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence __all__ = [ - 'new_event_loop', 'set_event_loop', 'get_event_loop', 'run_until_complete', 'set_event_loop_policy', - 'reset_event_loop_policy', 'PlumpyEventLoopPolicy' + 'new_event_loop', + 'set_event_loop', + 'get_event_loop', + 'run_until_complete', + 'set_event_loop_policy', + 'reset_event_loop_policy', + 'PlumpyEventLoopPolicy', ] if TYPE_CHECKING: diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 365b8008..b25cfa79 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -2,6 +2,7 @@ """ Module containing future related methods and classes """ + import asyncio from typing import Any, Callable, Coroutine, Optional @@ -35,7 +36,7 @@ def __init__(self, action: Callable[..., Any], cookie: Any = None): @property def cookie(self) -> Any: - """ A cookie that can be used to correlate the actions with something """ + """A cookie that can be used to correlate the actions with something""" return self._cookie def run(self, *args: Any, **kwargs: Any) -> None: diff --git a/src/plumpy/lang.py b/src/plumpy/lang.py index 6d9290af..9672d500 100644 --- a/src/plumpy/lang.py +++ b/src/plumpy/lang.py @@ -2,13 +2,13 @@ """ Python language utilities and tools. """ + import functools import inspect from typing import Any, Callable def protected(check: bool = False) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - def wrap(func: Callable[..., Any]) -> Callable[..., Any]: if isinstance(func, property): raise RuntimeError('Protected must go after @property decorator') @@ -68,7 +68,6 @@ def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Callable[..., Any]: class __NULL: # pylint: disable=invalid-name - def __eq__(self, other: Any) -> bool: return isinstance(other, self.__class__) diff --git a/src/plumpy/mixins.py b/src/plumpy/mixins.py index a8dcca1e..10142eb7 100644 --- a/src/plumpy/mixins.py +++ b/src/plumpy/mixins.py @@ -12,6 +12,7 @@ class ContextMixin(persistence.Savable): Add a context to a Process. The contents of the context will be saved in the instance state unlike standard instance variables. """ + CONTEXT: str = '_context' def __init__(self, *args: Any, **kwargs: Any): diff --git a/src/plumpy/persistence.py b/src/plumpy/persistence.py index 7a15b1cc..5598c147 100644 --- a/src/plumpy/persistence.py +++ b/src/plumpy/persistence.py @@ -18,8 +18,15 @@ from .utils import PID_TYPE, SAVED_STATE_TYPE __all__ = [ - 'Bundle', 'Persister', 'PicklePersister', 'auto_persist', 'Savable', 'SavableFuture', 'LoadSaveContext', - 'PersistedCheckpoint', 'InMemoryPersister' + 'Bundle', + 'Persister', + 'PicklePersister', + 'auto_persist', + 'Savable', + 'SavableFuture', + 'LoadSaveContext', + 'PersistedCheckpoint', + 'InMemoryPersister', ] PersistedCheckpoint = collections.namedtuple('PersistedCheckpoint', ['pid', 'tag']) @@ -29,7 +36,6 @@ class Bundle(dict): - def __init__(self, savable: 'Savable', save_context: Optional['LoadSaveContext'] = None, dereference: bool = False): """ Create a bundle from a savable. Optionally keep information about the @@ -77,7 +83,6 @@ def _bundle_constructor(loader: yaml.Loader, data: Any) -> Generator[Bundle, Non class Persister(metaclass=abc.ABCMeta): - @abc.abstractmethod def save_checkpoint(self, process: 'Process', tag: Optional[str] = None) -> None: """ @@ -301,7 +306,7 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: class InMemoryPersister(Persister): - """ Mainly to be used in testing/debugging """ + """Mainly to be used in testing/debugging""" def __init__(self, loader: Optional[loaders.ObjectLoader] = None) -> None: super().__init__() @@ -344,7 +349,6 @@ def delete_process_checkpoints(self, pid: PID_TYPE) -> None: def auto_persist(*members: str) -> Callable[[SavableClsType], SavableClsType]: - def wrapped(savable: SavableClsType) -> SavableClsType: # pylint: disable=protected-access if savable._auto_persist is None: @@ -390,7 +394,6 @@ def _ensure_object_loader(context: Optional['LoadSaveContext'], saved_state: SAV class LoadSaveContext: - def __init__(self, loader: Optional[loaders.ObjectLoader] = None, **kwargs: Any) -> None: self._values = dict(**kwargs) self.loader = loader @@ -408,7 +411,7 @@ def __contains__(self, item: Any) -> bool: return self._values.__contains__(item) def copyextend(self, **kwargs: Any) -> 'LoadSaveContext': - """ Add additional information to the context by making a copy with the new values """ + """Add additional information to the context by making a copy with the new values""" extended = self._values.copy() extended.update(kwargs) loader = extended.pop('loader', self.loader) @@ -527,10 +530,7 @@ def save_members(self, members: Iterable[str], out_state: SAVED_STATE_TYPE) -> N out_state[member] = value def load_members( - self, - members: Iterable[str], - saved_state: SAVED_STATE_TYPE, - load_context: Optional[LoadSaveContext] = None + self, members: Iterable[str], saved_state: SAVED_STATE_TYPE, load_context: Optional[LoadSaveContext] = None ) -> None: for member in members: setattr(self, member, self._get_value(saved_state, member, load_context)) @@ -580,8 +580,9 @@ def _get_meta_type(saved_state: SAVED_STATE_TYPE, name: str) -> Any: # endregion - def _get_value(self, saved_state: SAVED_STATE_TYPE, name: str, - load_context: Optional[LoadSaveContext]) -> Union[MethodType, 'Savable']: + def _get_value( + self, saved_state: SAVED_STATE_TYPE, name: str, load_context: Optional[LoadSaveContext] + ) -> Union[MethodType, 'Savable']: value = saved_state[name] typ = Savable._get_meta_type(saved_state, name) diff --git a/src/plumpy/ports.py b/src/plumpy/ports.py index fc5f138f..3ddc6554 100644 --- a/src/plumpy/ports.py +++ b/src/plumpy/ports.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- """Module for process ports""" + import collections import copy import inspect import json import logging -from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional, Sequence, Type, Union, cast import warnings +from typing import Any, Callable, Dict, Iterator, List, Mapping, MutableMapping, Optional, Sequence, Type, Union, cast from plumpy.utils import AttributesFrozendict, is_mutable_property, type_check @@ -68,7 +69,7 @@ def __init__( valid_type: Optional[Type[Any]] = None, help: Optional[str] = None, # pylint: disable=redefined-builtin required: bool = True, - validator: Optional[VALIDATOR_TYPE] = None + validator: Optional[VALIDATOR_TYPE] = None, ) -> None: self._name = name self._valid_type = valid_type @@ -236,14 +237,14 @@ def __init__( help: Optional[str] = None, # pylint: disable=redefined-builtin default: Any = UNSPECIFIED, required: bool = True, - validator: Optional[VALIDATOR_TYPE] = None + validator: Optional[VALIDATOR_TYPE] = None, ) -> None: # pylint: disable=too-many-arguments super().__init__( name, valid_type=valid_type, help=help, required=InputPort.required_override(required, default), - validator=validator + validator=validator, ) if required is not InputPort.required_override(required, default): @@ -252,7 +253,6 @@ def __init__( ) if default is not UNSPECIFIED: - # Only validate the default value if it is not a callable. If it is a callable its return value will always # be validated when the port is validated upon process construction, if the default is was actually used. if not callable(default): @@ -310,7 +310,7 @@ def __init__( valid_type: Optional[Type[Any]] = None, default: Any = UNSPECIFIED, dynamic: bool = False, - populate_defaults: bool = True + populate_defaults: bool = True, ) -> None: # pylint: disable=too-many-arguments """Construct a port namespace. @@ -459,7 +459,7 @@ def get_port(self, name: str, create_dynamically: bool = False) -> Union[Port, ' valid_type=self.valid_type, default=self.default, dynamic=self.dynamic, - populate_defaults=self.populate_defaults + populate_defaults=self.populate_defaults, ) if namespace: @@ -495,7 +495,6 @@ def create_port_namespace(self, name: str, **kwargs: Any) -> 'PortNamespace': # If this is True, the (sub) port namespace does not yet exist, so we create it if port_name not in self: - # If there still is a `namespace`, we create a sub namespace, *without* the constructor arguments if namespace: self[port_name] = self.__class__(port_name) @@ -515,7 +514,7 @@ def absorb( port_namespace: 'PortNamespace', exclude: Optional[Sequence[str]] = None, include: Optional[Sequence[str]] = None, - namespace_options: Optional[Dict[str, Any]] = None + namespace_options: Optional[Dict[str, Any]] = None, ) -> List[str]: """Absorb another PortNamespace instance into oneself, including all its mutable properties and ports. @@ -559,14 +558,12 @@ def absorb( absorbed_ports = [] for port_name, port in port_namespace.items(): - # If the current port name occurs in the exclude list, simply skip it entirely, there is no need to consider # any of the nested ports it might have, even if it is a port namespace if exclude and port_name in exclude: continue if isinstance(port, PortNamespace): - # If the name does not appear at the start of any of the include rules we continue: if include and not any(rule.startswith(port_name) for rule in include): continue @@ -616,9 +613,7 @@ def project(self, port_values: MutableMapping[str, Any]) -> MutableMapping[str, return result def validate( # pylint: disable=arguments-differ - self, - port_values: Optional[Mapping[str, Any]] = None, - breadcrumbs: Sequence[str] = () + self, port_values: Optional[Mapping[str, Any]] = None, breadcrumbs: Sequence[str] = () ) -> Optional[PortValidationError]: """ Validate the namespace port itself and subsequently all the port_values it contains @@ -669,8 +664,9 @@ def validate( # pylint: disable=arguments-differ else: message = self.validator(port_values_clone, self) # pylint: disable=not-callable if message is not None: - assert isinstance(message, str), \ - f"Validator returned something other than None or str: '{type(message)}'" + assert isinstance( + message, str + ), f"Validator returned something other than None or str: '{type(message)}'" return PortValidationError(message, breadcrumbs_to_port(breadcrumbs_local)) return None @@ -682,14 +678,12 @@ def pre_process(self, port_values: MutableMapping[str, Any]) -> AttributesFrozen :return: an AttributesFrozenDict with pre-processed port value mapping, complemented with port default values """ for name, port in self.items(): - # If the port was not specified in the inputs values and the port is a namespace with the property # `populate_defaults=False`, we skip the pre-processing and do not populate defaults. if name not in port_values and isinstance(port, PortNamespace) and not port.populate_defaults: continue if name not in port_values: - if port.has_default(): default = port.default if callable(default): @@ -712,8 +706,9 @@ def pre_process(self, port_values: MutableMapping[str, Any]) -> AttributesFrozen return AttributesFrozendict(port_values) - def validate_ports(self, port_values: MutableMapping[str, Any], - breadcrumbs: Sequence[str]) -> Optional[PortValidationError]: + def validate_ports( + self, port_values: MutableMapping[str, Any], breadcrumbs: Sequence[str] + ) -> Optional[PortValidationError]: """ Validate port values with respect to the explicitly defined ports of the port namespace. Ports values that are matched to an actual Port will be popped from the dictionary @@ -791,7 +786,7 @@ def strip_namespace(namespace: str, separator: str, rules: Optional[Sequence[str for rule in rules: if rule.startswith(prefix): - stripped.append(rule[len(prefix):]) + stripped.append(rule[len(prefix) :]) return stripped diff --git a/src/plumpy/process_comms.py b/src/plumpy/process_comms.py index c66e8431..c1588a27 100644 --- a/src/plumpy/process_comms.py +++ b/src/plumpy/process_comms.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Module for process level communication functions and classes""" + import asyncio import copy import logging @@ -34,6 +35,7 @@ class Intent: """Intent constants for a process message""" + # pylint: disable=too-few-public-methods PLAY: str = 'play' PAUSE: str = 'pause' @@ -71,7 +73,7 @@ def create_launch_body( init_kwargs: Optional[Dict[str, Any]] = None, persist: bool = False, loader: Optional[loaders.ObjectLoader] = None, - nowait: bool = True + nowait: bool = True, ) -> Dict[str, Any]: """ Create a message body for the launch action @@ -95,8 +97,8 @@ def create_launch_body( PERSIST_KEY: persist, NOWAIT_KEY: nowait, ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs - } + KWARGS_KEY: init_kwargs, + }, } return msg_body @@ -119,7 +121,7 @@ def create_create_body( init_args: Optional[Sequence[Any]] = None, init_kwargs: Optional[Dict[str, Any]] = None, persist: bool = False, - loader: Optional[loaders.ObjectLoader] = None + loader: Optional[loaders.ObjectLoader] = None, ) -> Dict[str, Any]: """ Create a message body to create a new process @@ -140,8 +142,8 @@ def create_create_body( PROCESS_CLASS_KEY: loader.identify_object(process_class), PERSIST_KEY: persist, ARGS_KEY: init_args, - KWARGS_KEY: init_kwargs - } + KWARGS_KEY: init_kwargs, + }, } return msg_body @@ -216,11 +218,7 @@ async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'Pro return result async def continue_process( - self, - pid: 'PID_TYPE', - tag: Optional[str] = None, - nowait: bool = False, - no_reply: bool = False + self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False ) -> Optional['ProcessResult']: """ Continue the process @@ -249,7 +247,7 @@ async def launch_process( persist: bool = False, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> 'ProcessResult': """ Launch a process given the class and constructor arguments @@ -281,7 +279,7 @@ async def execute_process( init_kwargs: Optional[Dict[str, Any]] = None, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> 'ProcessResult': """ Execute a process. This call will first send a create task and then a continue task over @@ -399,11 +397,7 @@ def kill_all(self, msg: Optional[Any]) -> None: self._communicator.broadcast_send(msg, subject=Intent.KILL) def continue_process( - self, - pid: 'PID_TYPE', - tag: Optional[str] = None, - nowait: bool = False, - no_reply: bool = False + self, pid: 'PID_TYPE', tag: Optional[str] = None, nowait: bool = False, no_reply: bool = False ) -> Union[None, PID_TYPE, ProcessResult]: message = create_continue_body(pid=pid, tag=tag, nowait=nowait) return self.task_send(message, no_reply=no_reply) @@ -416,7 +410,7 @@ def launch_process( persist: bool = False, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> Union[None, PID_TYPE, ProcessResult]: # pylint: disable=too-many-arguments """ @@ -441,7 +435,7 @@ def execute_process( init_kwargs: Optional[Dict[str, Any]] = None, loader: Optional[loaders.ObjectLoader] = None, nowait: bool = False, - no_reply: bool = False + no_reply: bool = False, ) -> Union[None, PID_TYPE, ProcessResult]: """ Execute a process. This call will first send a create task and then a continue task over @@ -512,7 +506,7 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, persister: Optional[persistence.Persister] = None, load_context: Optional[persistence.LoadSaveContext] = None, - loader: Optional[loaders.ObjectLoader] = None + loader: Optional[loaders.ObjectLoader] = None, ) -> None: self._loop = loop self._persister = persister @@ -581,11 +575,7 @@ async def _launch( return proc.future().result() async def _continue( - self, - _communicator: kiwipy.Communicator, - pid: 'PID_TYPE', - nowait: bool, - tag: Optional[str] = None + self, _communicator: kiwipy.Communicator, pid: 'PID_TYPE', nowait: bool, tag: Optional[str] = None ) -> Union[PID_TYPE, ProcessResult]: """ Continue the process diff --git a/src/plumpy/process_listener.py b/src/plumpy/process_listener.py index 110394a2..d49b8994 100644 --- a/src/plumpy/process_listener.py +++ b/src/plumpy/process_listener.py @@ -13,7 +13,6 @@ @persistence.auto_persist('_params') class ProcessListener(persistence.Savable, metaclass=abc.ABCMeta): - # region Persistence methods def __init__(self) -> None: diff --git a/src/plumpy/process_spec.py b/src/plumpy/process_spec.py index c82d59ee..4cb81196 100644 --- a/src/plumpy/process_spec.py +++ b/src/plumpy/process_spec.py @@ -22,6 +22,7 @@ class ProcessSpec: Every Process class has one of these. """ + NAME_INPUTS_PORT_NAMESPACE: str = 'inputs' NAME_OUTPUTS_PORT_NAMESPACE: str = 'outputs' PORT_NAMESPACE_TYPE = PortNamespace @@ -184,7 +185,7 @@ def expose_inputs( namespace: Optional[str] = None, exclude: Optional[Sequence[str]] = None, include: Optional[Sequence[str]] = None, - namespace_options: Optional[dict] = None + namespace_options: Optional[dict] = None, ) -> None: """ This method allows one to automatically add the inputs from another Process to this ProcessSpec. @@ -215,7 +216,7 @@ def expose_outputs( namespace: Optional[str] = None, exclude: Optional[Sequence[str]] = None, include: Optional[Sequence[str]] = None, - namespace_options: Optional[dict] = None + namespace_options: Optional[dict] = None, ) -> None: """ This method allows one to automatically add the ouputs from another Process to this ProcessSpec. @@ -249,7 +250,7 @@ def _expose_ports( namespace: Optional[str], exclude: Optional[Sequence[str]], include: Optional[Sequence[str]], - namespace_options: Optional[dict] = None + namespace_options: Optional[dict] = None, ) -> None: # pylint: disable=too-many-arguments """ Expose ports from a source PortNamespace of the ProcessSpec of a Process class into the destination diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index 3407412d..3d25f696 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -from enum import Enum import sys import traceback +from enum import Enum from types import TracebackType from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast @@ -64,7 +64,6 @@ class Command(persistence.Savable): @auto_persist('msg') class Kill(Command): - def __init__(self, msg: Optional[Any] = None): super().__init__() self.msg = msg @@ -76,7 +75,6 @@ class Pause(Command): @auto_persist('msg', 'data') class Wait(Command): - def __init__( self, continue_fn: Optional[Callable[..., Any]] = None, msg: Optional[Any] = None, data: Optional[Any] = None ): @@ -88,7 +86,6 @@ def __init__( @auto_persist('result') class Stop(Command): - def __init__(self, result: Any, successful: bool) -> None: super().__init__() self.result = result @@ -127,6 +124,7 @@ class ProcessState(Enum): """ The possible states that a :class:`~plumpy.processes.Process` can be in. """ + CREATED: str = 'created' RUNNING: str = 'running' WAITING: str = 'waiting' @@ -137,7 +135,6 @@ class ProcessState(Enum): @auto_persist('in_state') class State(state_machine.State, persistence.Savable): - @property def process(self) -> state_machine.StateMachine: """ @@ -183,7 +180,11 @@ def execute(self) -> state_machine.State: class Running(State): LABEL = ProcessState.RUNNING ALLOWED = { - ProcessState.RUNNING, ProcessState.WAITING, ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED + ProcessState.RUNNING, + ProcessState.WAITING, + ProcessState.FINISHED, + ProcessState.KILLED, + ProcessState.EXCEPTED, } RUN_FN = 'run_fn' # The key used to store the function to run @@ -267,7 +268,11 @@ def _action_command(self, command: Union[Kill, Stop, Wait, Continue]) -> State: class Waiting(State): LABEL = ProcessState.WAITING ALLOWED = { - ProcessState.RUNNING, ProcessState.WAITING, ProcessState.KILLED, ProcessState.EXCEPTED, ProcessState.FINISHED + ProcessState.RUNNING, + ProcessState.WAITING, + ProcessState.KILLED, + ProcessState.EXCEPTED, + ProcessState.FINISHED, } DONE_CALLBACK = 'DONE_CALLBACK' @@ -285,7 +290,7 @@ def __init__( process: 'Process', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - data: Optional[Any] = None + data: Optional[Any] = None, ) -> None: super().__init__(process) self.done_callback = done_callback @@ -370,9 +375,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi self.exception = yaml.load(saved_state[self.EXC_VALUE], Loader=Loader) if _HAS_TBLIB: try: - self.traceback = \ - tblib.Traceback.from_string(saved_state[self.TRACEBACK], - strict=False) + self.traceback = tblib.Traceback.from_string(saved_state[self.TRACEBACK], strict=False) except KeyError: self.traceback = None else: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index b6e14ad9..d380bf3a 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """The main Process module""" + import abc import asyncio import contextlib @@ -10,6 +11,8 @@ import re import sys import time +import uuid +import warnings from types import TracebackType from typing import ( Any, @@ -26,17 +29,15 @@ Union, cast, ) -import uuid -import warnings try: from aiocontextvars import ContextVar except ModuleNotFoundError: from contextvars import ContextVar -from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed import kiwipy import yaml +from aio_pika.exceptions import ChannelInvalidStateError, ConnectionClosed from . import events, exceptions, futures, persistence, ports, process_comms, process_states, utils from .base import state_machine @@ -62,6 +63,7 @@ class BundleKeys: See :meth:`plumpy.processes.Process.save_instance_state` and :meth:`plumpy.processes.Process.load_instance_state`. """ + # pylint: disable=too-few-public-methods INPUTS_RAW = 'INPUTS_RAW' INPUTS_PARSED = 'INPUTS_PARSED' @@ -75,7 +77,7 @@ class ProcessStateMachineMeta(abc.ABCMeta, state_machine.StateMachineMeta): # Make ProcessStateMachineMeta instances (classes) YAML - able yaml.representer.Representer.add_representer( ProcessStateMachineMeta, - yaml.representer.Representer.represent_name # type: ignore[arg-type] + yaml.representer.Representer.represent_name, # type: ignore[arg-type] ) @@ -167,7 +169,7 @@ def get_states(cls) -> Sequence[Type[process_states.State]]: state_classes = cls.get_state_classes() return ( state_classes[process_states.ProcessState.CREATED], - *[state for state in state_classes.values() if state.LABEL != process_states.ProcessState.CREATED] + *[state for state in state_classes.values() if state.LABEL != process_states.ProcessState.CREATED], ) @classmethod @@ -179,7 +181,7 @@ def get_state_classes(cls) -> Dict[Hashable, Type[process_states.State]]: process_states.ProcessState.WAITING: process_states.Waiting, process_states.ProcessState.FINISHED: process_states.Finished, process_states.ProcessState.EXCEPTED: process_states.Excepted, - process_states.ProcessState.KILLED: process_states.Killed + process_states.ProcessState.KILLED: process_states.Killed, } @classmethod @@ -256,7 +258,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None + communicator: Optional[kiwipy.Communicator] = None, ) -> None: """ The signature of the constructor should not be changed by subclassing processes. @@ -278,8 +280,9 @@ def __init__( self._setup_event_hooks() self._status: Optional[str] = None # May hold a current status message - self._pre_paused_status: Optional[ - str] = None # Save status when a pause message replaces it, such that it can be restored + self._pre_paused_status: Optional[str] = ( + None # Save status when a pause message replaces it, such that it can be restored + ) self._paused = None # Input/output @@ -331,12 +334,13 @@ def try_killing(future: futures.Future) -> None: def _setup_event_hooks(self) -> None: """Set the event hooks to process, when it is created or loaded(recreated).""" event_hooks = { - state_machine.StateEventHook.ENTERING_STATE: - lambda _s, _h, state: self.on_entering(cast(process_states.State, state)), - state_machine.StateEventHook.ENTERED_STATE: - lambda _s, _h, from_state: self.on_entered(cast(Optional[process_states.State], from_state)), - state_machine.StateEventHook.EXITING_STATE: - lambda _s, _h, _state: self.on_exiting() + state_machine.StateEventHook.ENTERING_STATE: lambda _s, _h, state: self.on_entering( + cast(process_states.State, state) + ), + state_machine.StateEventHook.ENTERED_STATE: lambda _s, _h, from_state: self.on_entered( + cast(Optional[process_states.State], from_state) + ), + state_machine.StateEventHook.EXITING_STATE: lambda _s, _h, _state: self.on_exiting(), } for hook, callback in event_hooks.items(): self.add_state_event_callback(hook, callback) @@ -356,7 +360,7 @@ def pid(self) -> Optional[PID_TYPE]: @property def uuid(self) -> Optional[uuid.UUID]: - """Return the UUID of the process """ + """Return the UUID of the process""" return self._uuid @property @@ -421,7 +425,7 @@ def launch( process_class: Type['Process'], inputs: Optional[dict] = None, pid: Optional[PID_TYPE] = None, - logger: Optional[logging.Logger] = None + logger: Optional[logging.Logger] = None, ) -> 'Process': """Start running the nested process. @@ -663,7 +667,7 @@ def add_process_listener(self, listener: ProcessListener) -> None: the specific state condition. """ - assert (listener != self), 'Cannot listen to yourself!' # type: ignore + assert listener != self, 'Cannot listen to yourself!' # type: ignore self._event_helper.add_listener(listener) def remove_process_listener(self, listener: ProcessListener) -> None: @@ -926,8 +930,9 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An # Didn't match any known intents raise RuntimeError('Unknown intent') - def broadcast_receive(self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, - correlation_id: Any) -> Optional[kiwipy.Future]: + def broadcast_receive( + self, _comm: kiwipy.Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any + ) -> Optional[kiwipy.Future]: """ Coroutine called when the process receives a message from the communicator @@ -1044,7 +1049,7 @@ def pause(self, msg: Union[str, None] = None) -> Union[bool, futures.Cancellable return self._do_pause(msg) def _do_pause(self, state_msg: Optional[str], next_state: Optional[process_states.State] = None) -> bool: - """ Carry out the pause procedure, optionally transitioning to the next state first""" + """Carry out the pause procedure, optionally transitioning to the next state first""" try: if next_state is not None: self.transition_to(next_state) @@ -1091,7 +1096,7 @@ def _set_interrupt_action(self, new_action: Optional[futures.CancellableAction]) self._interrupt_action = new_action def _set_interrupt_action_from_exception(self, interrupt_exception: process_states.Interruption) -> None: - """ Set an interrupt action from the corresponding interrupt exception """ + """Set an interrupt action from the corresponding interrupt exception""" action = self._create_interrupt_action(interrupt_exception) self._set_interrupt_action(action) @@ -1285,7 +1290,7 @@ def out(self, output_port: str, value: Any) -> None: if namespace: port_namespace = cast( ports.PortNamespace, - self.spec().outputs.get_port(namespace_separator.join(namespace), create_dynamically=True) + self.spec().outputs.get_port(namespace_separator.join(namespace), create_dynamically=True), ) else: port_namespace = self.spec().outputs @@ -1341,9 +1346,11 @@ def get_status_info(self, out_status_info: dict) -> None: :param out_status_info: the old status """ - out_status_info.update({ - 'ctime': self.creation_time, - 'paused': self.paused, - 'process_string': str(self), - 'state': str(self.state), - }) + out_status_info.update( + { + 'ctime': self.creation_time, + 'paused': self.paused, + 'process_string': str(self), + 'state': str(self.state), + } + ) diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index 4eab8efe..b8a8e8be 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -1,20 +1,30 @@ # -*- coding: utf-8 -*- import asyncio -from collections import deque -from collections.abc import Mapping import functools import importlib import inspect import logging import types -from typing import Set # pylint: disable=unused-import -from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, List, MutableMapping, Optional, Tuple, Type +from collections import deque +from collections.abc import Mapping +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Hashable, + Iterator, + List, + MutableMapping, + Optional, + Tuple, + Type, +) from . import lang from .settings import check_override, check_protected if TYPE_CHECKING: - from .process_listener import ProcessListener # pylint: disable=cyclic-import + pass # pylint: disable=cyclic-import __all__ = ['AttributesDict'] @@ -67,7 +77,6 @@ def __hash__(self) -> int: class AttributesFrozendict(Frozendict): - def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self._initialised: bool = True diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 90e35482..e675a6a0 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -22,7 +22,6 @@ class WorkChainSpec(processes.ProcessSpec): - def __init__(self) -> None: super().__init__() self._outline: Optional[Union['_Instruction', '_FunctionCall']] = None @@ -55,14 +54,14 @@ def get_outline(self) -> Union['_Instruction', '_FunctionCall']: @persistence.auto_persist('_awaiting') class Waiting(process_states.Waiting): - """ Overwrite the waiting state""" + """Overwrite the waiting state""" def __init__( self, process: 'WorkChain', done_callback: Optional[Callable[..., Any]], msg: Optional[str] = None, - awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None + awaiting: Optional[Dict[Union[asyncio.Future, processes.Process], str]] = None, ) -> None: super().__init__(process, done_callback, msg, awaiting) self._awaiting: Dict[asyncio.Future, str] = {} @@ -97,6 +96,7 @@ class WorkChain(mixins.ContextMixin, processes.Process): A WorkChain is a series of instructions carried out with the ability to save state in between. """ + _spec_class = WorkChainSpec _STEPPER_STATE = 'stepper_state' _CONTEXT = 'CONTEXT' @@ -113,7 +113,7 @@ def __init__( pid: Optional[PID_TYPE] = None, logger: Optional[logging.Logger] = None, loop: Optional[asyncio.AbstractEventLoop] = None, - communicator: Optional[kiwipy.Communicator] = None + communicator: Optional[kiwipy.Communicator] = None, ) -> None: super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator) self._stepper: Optional[Stepper] = None @@ -169,7 +169,6 @@ def _do_step(self) -> Any: finished, return_value = True, exception.exit_code if not finished and (return_value is None or isinstance(return_value, ToContext)): - if isinstance(return_value, ToContext): self.to_context(**return_value) @@ -182,7 +181,6 @@ def _do_step(self) -> Any: class Stepper(persistence.Savable, metaclass=abc.ABCMeta): - def __init__(self, workchain: 'WorkChain') -> None: self._workchain = workchain @@ -210,11 +208,11 @@ class _Instruction(metaclass=abc.ABCMeta): @abc.abstractmethod def create_stepper(self, workchain: 'WorkChain') -> Stepper: - """ Create a new stepper for this instruction """ + """Create a new stepper for this instruction""" @abc.abstractmethod def recreate_stepper(self, saved_state: SAVED_STATE_TYPE, workchain: 'WorkChain') -> Stepper: - """ Recreate a stepper from a previously saved state """ + """Recreate a stepper from a previously saved state""" def __str__(self) -> str: return str(self.get_description()) @@ -229,7 +227,6 @@ def get_description(self) -> Any: class _FunctionStepper(Stepper): - def __init__(self, workchain: 'WorkChain', fn: WC_COMMAND_TYPE): super().__init__(workchain) self._fn = fn @@ -250,7 +247,6 @@ def __str__(self) -> str: class _FunctionCall(_Instruction): - def __init__(self, func: WC_COMMAND_TYPE) -> None: try: args = inspect.getfullargspec(func)[0] @@ -282,7 +278,6 @@ def get_description(self) -> str: @persistence.auto_persist('_pos') class _BlockStepper(Stepper): - def __init__(self, block: Sequence[_Instruction], workchain: 'WorkChain') -> None: super().__init__(workchain) self._block = block @@ -392,10 +387,12 @@ def is_true(self, workflow: 'WorkChain') -> bool: if not hasattr(result, '__bool__'): import warnings + warnings.warn( f'The conditional predicate `{self._predicate.__name__}` returned `{result}` which is not boolean-like.' ' The return value should be `True` or `False` or implement the `__bool__` method. This behavior is ' - 'deprecated and will soon start raising an exception.', UserWarning + 'deprecated and will soon start raising an exception.', + UserWarning, ) return result @@ -411,7 +408,6 @@ def __str__(self) -> str: @persistence.auto_persist('_pos') class _IfStepper(Stepper): - def __init__(self, if_instruction: '_If', workchain: 'WorkChain') -> None: super().__init__(workchain) self._if_instruction = if_instruction @@ -467,7 +463,6 @@ def __str__(self) -> str: class _If(_Instruction, collections.abc.Sequence): - def __init__(self, condition: PREDICATE_TYPE) -> None: super().__init__() self._ifs: List[_Conditional] = [_Conditional(self, condition, label=if_.__name__)] @@ -520,7 +515,6 @@ def get_description(self) -> Mapping[str, Any]: class _WhileStepper(Stepper): - def __init__(self, while_instruction: '_While', workchain: 'WorkChain') -> None: super().__init__(workchain) self._while_instruction = while_instruction @@ -563,7 +557,6 @@ def __str__(self) -> str: class _While(_Conditional, _Instruction, collections.abc.Sequence): - def __init__(self, predicate: PREDICATE_TYPE) -> None: super().__init__(self, predicate, label=while_.__name__) @@ -586,14 +579,12 @@ def get_description(self) -> Dict[str, Any]: class _PropagateReturn(BaseException): - def __init__(self, exit_code: Optional[EXIT_CODE_TYPE]) -> None: super().__init__() self.exit_code = exit_code class _ReturnStepper(Stepper): - def __init__(self, return_instruction: '_Return', workchain: 'WorkChain') -> None: super().__init__(workchain) self._return_instruction = return_instruction diff --git a/test/base/test_statemachine.py b/test/base/test_statemachine.py index 72fed261..5a8deb87 100644 --- a/test/base/test_statemachine.py +++ b/test/base/test_statemachine.py @@ -25,7 +25,7 @@ def __init__(self, player, track): super().__init__(player) self.track = track self._last_time = None - self._played = 0. + self._played = 0.0 def __str__(self): if self.in_state: @@ -55,8 +55,7 @@ class Paused(state_machine.State): TRANSITIONS = {STOP: STOPPED} def __init__(self, player, playing_state): - assert isinstance(playing_state, Playing), \ - 'Must provide the playing state to pause' + assert isinstance(playing_state, Playing), 'Must provide the playing state to pause' super().__init__(player) self.playing_state = playing_state @@ -117,14 +116,13 @@ def stop(self): class TestStateMachine(unittest.TestCase): - def test_basic(self): cd_player = CdPlayer() self.assertEqual(cd_player.state, STOPPED) cd_player.play('Eminem - The Real Slim Shady') self.assertEqual(cd_player.state, PLAYING) - time.sleep(1.) + time.sleep(1.0) cd_player.pause() self.assertEqual(cd_player.state, PAUSED) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 9aa0237b..d62b1422 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -5,7 +5,6 @@ class Root: - @utils.super_check def method(self): pass @@ -15,19 +14,16 @@ def do(self): class DoCall(Root): - def method(self): super().method() class DontCall(Root): - def method(self): pass class TestSuperCheckMixin(unittest.TestCase): - def test_do_call(self): DoCall().do() @@ -36,9 +32,7 @@ def test_dont_call(self): DontCall().do() def dont_call_middle(self): - class ThirdChild(DontCall): - def method(self): super().method() diff --git a/test/conftest.py b/test/conftest.py index 43555586..c70088fa 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,4 +5,5 @@ @pytest.fixture(scope='session') def set_event_loop_policy(): from plumpy import set_event_loop_policy + set_event_loop_policy() diff --git a/test/persistence/test_inmemory.py b/test/persistence/test_inmemory.py index bc03f88b..3de2f890 100644 --- a/test/persistence/test_inmemory.py +++ b/test/persistence/test_inmemory.py @@ -1,13 +1,13 @@ # -*- coding: utf-8 -*- import asyncio -from test.utils import ProcessWithCheckpoint import unittest import plumpy +from test.utils import ProcessWithCheckpoint + class TestInMemoryPersister(unittest.TestCase): - def test_save_load_roundtrip(self): """ Test the plumpy.PicklePersister by taking a dummpy process, saving a checkpoint @@ -24,8 +24,7 @@ def test_save_load_roundtrip(self): recreated = bundle.unbundle(load_context) def test_get_checkpoints_without_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -43,8 +42,7 @@ def test_get_checkpoints_without_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_checkpoints_with_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() tag_a = 'tag_a' @@ -64,8 +62,7 @@ def test_get_checkpoints_with_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -87,8 +84,7 @@ def test_get_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -116,8 +112,7 @@ def test_delete_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_checkpoint(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() diff --git a/test/persistence/test_pickle.py b/test/persistence/test_pickle.py index 19e4f52a..0046792d 100644 --- a/test/persistence/test_pickle.py +++ b/test/persistence/test_pickle.py @@ -6,13 +6,12 @@ if getattr(tempfile, 'TemporaryDirectory', None) is None: from backports import tempfile -from test.utils import ProcessWithCheckpoint - import plumpy +from test.utils import ProcessWithCheckpoint -class TestPicklePersister(unittest.TestCase): +class TestPicklePersister(unittest.TestCase): def test_save_load_roundtrip(self): """ Test the plumpy.PicklePersister by taking a dummpy process, saving a checkpoint @@ -30,8 +29,7 @@ def test_save_load_roundtrip(self): recreated = bundle.unbundle(load_context) def test_get_checkpoints_without_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -50,8 +48,7 @@ def test_get_checkpoints_without_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_checkpoints_with_tags(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() tag_a = 'tag_a' @@ -72,8 +69,7 @@ def test_get_checkpoints_with_tags(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_get_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -96,8 +92,7 @@ def test_get_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_process_checkpoints(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() @@ -126,8 +121,7 @@ def test_delete_process_checkpoints(self): self.assertSetEqual(set(retrieved_checkpoints), set(checkpoints)) def test_delete_checkpoint(self): - """ - """ + """ """ process_a = ProcessWithCheckpoint() process_b = ProcessWithCheckpoint() diff --git a/test/rmq/test_communicator.py b/test/rmq/test_communicator.py index 5cedd38d..1ef13a8e 100644 --- a/test/rmq/test_communicator.py +++ b/test/rmq/test_communicator.py @@ -1,17 +1,17 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.rmq.communicator` module.""" + import asyncio import functools import shutil import tempfile import uuid -from kiwipy import BroadcastFilter, rmq +import plumpy import pytest import shortuuid import yaml - -import plumpy +from kiwipy import BroadcastFilter, rmq from plumpy import communications, process_comms from .. import utils @@ -38,7 +38,7 @@ def loop_communicator(): message_exchange=message_exchange, task_exchange=task_exchange, task_queue=task_queue, - decoder=functools.partial(yaml.load, Loader=yaml.Loader) + decoder=functools.partial(yaml.load, Loader=yaml.Loader), ) loop = asyncio.get_event_loop() @@ -69,12 +69,9 @@ async def test_broadcast(self, loop_communicator): def get_broadcast(_comm, body, sender, subject, correlation_id): assert loop is asyncio.get_event_loop() - broadcast_future.set_result({ - 'body': body, - 'sender': sender, - 'subject': subject, - 'correlation_id': correlation_id - }) + broadcast_future.set_result( + {'body': body, 'sender': sender, 'subject': subject, 'correlation_id': correlation_id} + ) loop_communicator.add_broadcast_subscriber(get_broadcast) loop_communicator.broadcast_send(**BROADCAST) @@ -84,7 +81,6 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): @pytest.mark.asyncio async def test_broadcast_filter(self, loop_communicator): - broadcast_future = plumpy.Future() loop = asyncio.get_event_loop() @@ -98,12 +94,7 @@ def get_broadcast(_comm, body, sender, subject, correlation_id): loop_communicator.add_broadcast_subscriber(BroadcastFilter(ignore_broadcast, subject='other')) loop_communicator.add_broadcast_subscriber(get_broadcast) loop_communicator.broadcast_send( - **{ - 'body': 'present', - 'sender': 'Martin', - 'subject': 'sup', - 'correlation_id': 420 - } + **{'body': 'present', 'sender': 'Martin', 'subject': 'sup', 'correlation_id': 420} ) result = await broadcast_future @@ -145,7 +136,6 @@ def get_task(_comm, msg): class TestTaskActions: - @pytest.mark.asyncio async def test_launch(self, loop_communicator, async_controller, persister): # Let the process run to the end @@ -157,7 +147,7 @@ async def test_launch(self, loop_communicator, async_controller, persister): @pytest.mark.asyncio async def test_launch_nowait(self, loop_communicator, async_controller, persister): - """ Testing launching but don't wait, just get the pid """ + """Testing launching but don't wait, just get the pid""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) pid = await async_controller.launch_process(utils.DummyProcess, nowait=True) @@ -165,7 +155,7 @@ async def test_launch_nowait(self, loop_communicator, async_controller, persiste @pytest.mark.asyncio async def test_execute_action(self, loop_communicator, async_controller, persister): - """ Test the process execute action """ + """Test the process execute action""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) result = await async_controller.execute_process(utils.DummyProcessWithOutput) @@ -173,7 +163,7 @@ async def test_execute_action(self, loop_communicator, async_controller, persist @pytest.mark.asyncio async def test_execute_action_nowait(self, loop_communicator, async_controller, persister): - """ Test the process execute action """ + """Test the process execute action""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) pid = await async_controller.execute_process(utils.DummyProcessWithOutput, nowait=True) @@ -197,7 +187,7 @@ async def test_launch_many(self, loop_communicator, async_controller, persister) @pytest.mark.asyncio async def test_continue(self, loop_communicator, async_controller, persister): - """ Test continuing a saved process """ + """Test continuing a saved process""" loop = asyncio.get_event_loop() loop_communicator.add_task_subscriber(plumpy.ProcessLauncher(loop, persister=persister)) process = utils.DummyProcessWithOutput() diff --git a/test/rmq/test_process_comms.py b/test/rmq/test_process_comms.py index 6afccf46..97a949ab 100644 --- a/test/rmq/test_process_comms.py +++ b/test/rmq/test_process_comms.py @@ -2,13 +2,12 @@ import asyncio import kiwipy -from kiwipy import rmq +import plumpy +import plumpy.communications import pytest import shortuuid - -import plumpy +from kiwipy import rmq from plumpy import process_comms -import plumpy.communications from .. import utils @@ -43,7 +42,6 @@ def sync_controller(thread_communicator: rmq.RmqThreadCommunicator): class TestRemoteProcessController: - @pytest.mark.asyncio async def test_pause(self, thread_communicator, async_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) @@ -122,7 +120,6 @@ def on_broadcast_receive(**msg): class TestRemoteProcessThreadController: - @pytest.mark.asyncio async def test_pause(self, thread_communicator, sync_controller): proc = utils.WaitForSignalProcess(communicator=thread_communicator) diff --git a/test/test_communications.py b/test/test_communications.py index f82036bd..1691cbd7 100644 --- a/test/test_communications.py +++ b/test/test_communications.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.communications` module.""" -from kiwipy import CommunicatorHelper -import pytest +import pytest +from kiwipy import CommunicatorHelper from plumpy.communications import LoopCommunicator @@ -14,7 +14,6 @@ def __call__(self): class Communicator(CommunicatorHelper): - def task_send(self, task, no_reply=False): pass diff --git a/test/test_events.py b/test/test_events.py index e6260f1d..1dc2d325 100644 --- a/test/test_events.py +++ b/test/test_events.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.events` module.""" + import asyncio import pathlib import pytest - from plumpy import PlumpyEventLoopPolicy, new_event_loop, reset_event_loop_policy, set_event_loop, set_event_loop_policy diff --git a/test/test_expose.py b/test/test_expose.py index 1a495727..8ca191e3 100644 --- a/test/test_expose.py +++ b/test/test_expose.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- -from test.utils import NewLoopProcess import unittest from plumpy.ports import PortNamespace from plumpy.process_spec import ProcessSpec from plumpy.processes import Process +from test.utils import NewLoopProcess -class TestExposeProcess(unittest.TestCase): +class TestExposeProcess(unittest.TestCase): def setUp(self): super().setUp() @@ -16,7 +16,6 @@ def validator_function(input, port): pass class BaseNamespaceProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -27,7 +26,6 @@ def define(cls, spec): spec.inputs['namespace'].validator = validator_function class BaseProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -37,7 +35,6 @@ def define(cls, spec): spec.inputs.valid_type = str class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -78,14 +75,12 @@ def test_expose_dynamic(self): """Test that exposing a dynamic namespace remains dynamic.""" class Lower(Process): - @classmethod def define(cls, spec): super(Lower, cls).define(spec) spec.input_namespace('foo', dynamic=True) class Upper(Process): - @classmethod def define(cls, spec): super(Upper, cls).define(spec) @@ -150,7 +145,6 @@ def test_expose_exclude(self): BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -168,7 +162,6 @@ def test_expose_include(self): BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -186,7 +179,6 @@ def test_expose_exclude_include_mutually_exclusive(self): BaseProcess = self.BaseProcess class ExcludeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -228,7 +220,7 @@ def validator_function(input, port): namespace=None, exclude=(), include=None, - namespace_options={} + namespace_options={}, ) # Verify that all the ports are there @@ -283,7 +275,7 @@ def validator_function(input, port): 'dynamic': False, 'default': None, 'help': None, - } + }, ) # Verify that all the ports are there @@ -330,7 +322,7 @@ def validator_function(input, port): namespace='namespace', exclude=(), include=None, - namespace_options={} + namespace_options={}, ) # Verify that all the ports are there @@ -365,7 +357,7 @@ def test_expose_ports_namespace_options_non_existent(self): include=None, namespace_options={ 'non_existent': None, - } + }, ) def test_expose_nested_include_top_level(self): @@ -373,7 +365,6 @@ def test_expose_nested_include_top_level(self): BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -387,7 +378,6 @@ def test_expose_nested_include_namespace(self): BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -403,7 +393,6 @@ def test_expose_nested_include_namespace_sub(self): BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -419,7 +408,6 @@ def test_expose_nested_include_combination(self): BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -435,7 +423,6 @@ def test_expose_nested_exclude_top_level(self): BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -451,7 +438,6 @@ def test_expose_nested_exclude_namespace(self): BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -465,7 +451,6 @@ def test_expose_nested_exclude_namespace_sub(self): BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -481,7 +466,6 @@ def test_expose_nested_exclude_combination(self): BaseNamespaceProcess = self.BaseNamespaceProcess class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -504,7 +488,6 @@ def test_expose_exclude_port_with_validator(self): """ class BaseProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -520,10 +503,9 @@ def validator(cls, value, ctx): return None if not isinstance(value['a'], str): - return f'value for input `a` should be a str, but got: {type(value["a"])}' + return f'value for input `a` should be a str, but got: {type(value['a'])}' class ExposeProcess(NewLoopProcess): - @classmethod def define(cls, spec): super().define(spec) diff --git a/test/test_lang.py b/test/test_lang.py index 13136530..42125394 100644 --- a/test/test_lang.py +++ b/test/test_lang.py @@ -5,7 +5,6 @@ class A: - def __init__(self): self._a = None @@ -28,21 +27,18 @@ def testA(self): class B(A): - def testB(self): self.protected_fn() self.protected_property class C(B): - def testC(self): self.protected_fn() self.protected_property class TestProtected(TestCase): - def test_free_function(self): with self.assertRaises(RuntimeError): @@ -79,7 +75,6 @@ def test_incorrect_usage(self): with self.assertRaises(RuntimeError): class TestWrongDecoratorOrder: - @protected(check=True) @property def a(self): @@ -87,13 +82,11 @@ def a(self): class Superclass: - def test(self): pass class TestOverride(TestCase): - def test_free_function(self): with self.assertRaises(RuntimeError): @@ -102,9 +95,7 @@ def some_func(): pass def test_correct_usage(self): - class Derived(Superclass): - @override(check=True) def test(self): return True @@ -115,7 +106,6 @@ class Middle(Superclass): pass class Next(Middle): - @override(check=True) def test(self): return True @@ -123,9 +113,7 @@ def test(self): self.assertTrue(Next().test()) def test_incorrect_usage(self): - class Derived: - @override(check=True) def test(self): pass @@ -136,7 +124,6 @@ def test(self): with self.assertRaises(RuntimeError): class TestWrongDecoratorOrder(Superclass): - @override(check=True) @property def test(self): diff --git a/test/test_loaders.py b/test/test_loaders.py index 3058b77c..75fd2848 100644 --- a/test/test_loaders.py +++ b/test/test_loaders.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- """Tests for the :mod:`plumpy.loaders` module.""" -import pytest import plumpy +import pytest class DummyClass: """Dummy class for testing.""" + pass @@ -38,11 +39,12 @@ def test_default_object_roundtrip(): @pytest.mark.parametrize( - 'identifier, match', ( + 'identifier, match', + ( ('plumpy.non_existing_module.SomeClass', r'identifier `.*` has an invalid format.'), ('plumpy.non_existing_module:SomeClass', r'module `.*` from identifier `.*` could not be loaded.'), ('plumpy.loaders:NonExistingClass', r'object `.*` form identifier `.*` could not be loaded.'), - ) + ), ) def test_default_object_loader_load_object_except(identifier, match): """Test the :meth:`plumpy.DefaultObjectLoader.load_object` when it is expected to raise.""" diff --git a/test/test_persistence.py b/test/test_persistence.py index 2c9cf4f9..0ec3d5ab 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -2,9 +2,8 @@ import asyncio import unittest -import yaml - import plumpy +import yaml from . import utils @@ -15,7 +14,6 @@ class SaveEmpty(plumpy.Savable): @plumpy.auto_persist('test', 'test_method') class Save1(plumpy.Savable): - def __init__(self): self.test = 'sup yp' self.test_method = self.m @@ -26,13 +24,11 @@ def m(): @plumpy.auto_persist('test') class Save(plumpy.Savable): - def __init__(self): self.test = Save1() class TestSavable(unittest.TestCase): - def test_empty_savable(self): self._save_round_trip(SaveEmpty()) @@ -79,9 +75,8 @@ def _save_round_trip_with_loader(self, savable): class TestBundle(unittest.TestCase): - def test_bundle_load_context(self): - """ Check that the loop from the load context is used """ + """Check that the loop from the load context is used""" loop1 = asyncio.get_event_loop() proc = utils.DummyProcess(loop=loop1) bundle = plumpy.Bundle(proc) diff --git a/test/test_port.py b/test/test_port.py index ab9b51a6..1809f326 100644 --- a/test/test_port.py +++ b/test/test_port.py @@ -7,7 +7,6 @@ class TestPort(TestCase): - def test_required(self): spec = Port('required_value', required=True) @@ -21,7 +20,6 @@ def test_validate(self): self.assertIsNotNone(spec.validate('a')) def test_validator(self): - def validate(value, port): assert isinstance(port, Port) if not isinstance(value, int): @@ -45,7 +43,6 @@ def validate(value, port): class TestInputPort(TestCase): - def test_default(self): """Test the default value property for the InputPort.""" port = InputPort('test', default=5) @@ -86,7 +83,6 @@ def test_lambda_default(self): class TestOutputPort(TestCase): - def test_default(self): """ Test the default value property for the InputPort @@ -108,7 +104,6 @@ def validator(value, port): class TestPortNamespace(TestCase): - BASE_PORT_NAME = 'port' BASE_PORT_NAMESPACE_NAME = 'port' @@ -299,7 +294,7 @@ def test_port_namespace_validate(self): # Check the breadcrumbs are correct self.assertEqual( validation_error.port, - self.port_namespace.NAMESPACE_SEPARATOR.join((self.BASE_PORT_NAMESPACE_NAME, 'sub', 'space', 'output')) + self.port_namespace.NAMESPACE_SEPARATOR.join((self.BASE_PORT_NAMESPACE_NAME, 'sub', 'space', 'output')), ) def test_port_namespace_required(self): diff --git a/test/test_process_comms.py b/test/test_process_comms.py index 6d3d335c..5349a90d 100644 --- a/test/test_process_comms.py +++ b/test/test_process_comms.py @@ -1,23 +1,17 @@ # -*- coding: utf-8 -*- -import asyncio -from test import utils -import unittest - -from kiwipy import rmq +import plumpy import pytest +from plumpy import process_comms -import plumpy -from plumpy import communications, process_comms +from test import utils class Process(plumpy.Process): - def run(self): pass class CustomObjectLoader(plumpy.DefaultObjectLoader): - def load_object(self, identifier): if identifier == 'jimmy': return Process @@ -49,7 +43,7 @@ async def test_continue(): @pytest.mark.asyncio async def test_loader_is_used(): - """ Make sure that the provided class loader is used by the process launcher """ + """Make sure that the provided class loader is used by the process launcher""" loader = CustomObjectLoader() proc = Process() persister = plumpy.InMemoryPersister(loader=loader) diff --git a/test/test_process_spec.py b/test/test_process_spec.py index 443f7a64..a43ad936 100644 --- a/test/test_process_spec.py +++ b/test/test_process_spec.py @@ -10,7 +10,6 @@ class StrSubtype(str): class TestProcessSpec(TestCase): - def setUp(self): self.spec = ProcessSpec() diff --git a/test/test_processes.py b/test/test_processes.py index 0cb4161b..a7adefc7 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -1,20 +1,20 @@ # -*- coding: utf-8 -*- """Process tests""" + import asyncio import enum -from test import utils import unittest import kiwipy -import pytest - import plumpy +import pytest from plumpy import BundleKeys, Process, ProcessState from plumpy.utils import AttributesFrozendict +from test import utils -class ForgetToCallParent(plumpy.Process): +class ForgetToCallParent(plumpy.Process): def __init__(self, forget_on): super().__init__() self.forget_on = forget_on @@ -42,9 +42,7 @@ def on_kill(self, msg): @pytest.mark.asyncio async def test_process_scope(): - class ProcessTaskInterleave(plumpy.Process): - async def task(self, steps: list): steps.append(f'[{self.pid}] started') assert plumpy.Process.current() is self @@ -64,7 +62,6 @@ async def task(self, steps: list): class TestProcess(unittest.TestCase): - def test_spec(self): """ Check that the references to specs are doing the right thing... @@ -82,12 +79,10 @@ class Proc(utils.DummyProcess): self.assertIs(p.spec(), Proc.spec()) def test_dynamic_inputs(self): - class NoDynamic(Process): pass class WithDynamic(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -100,9 +95,7 @@ def define(cls, spec): proc.execute() def test_inputs(self): - class Proc(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -122,7 +115,6 @@ def test_raw_inputs(self): """ class Proc(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -138,9 +130,7 @@ def define(cls, spec): self.assertDictEqual(dict(process.raw_inputs), {'a': 5, 'nested': {'a': 'value'}}) def test_inputs_default(self): - class Proc(utils.DummyProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -199,7 +189,6 @@ def test_inputs_default_that_evaluate_to_false(self): for def_val in (True, False, 0, 1): class Proc(utils.DummyProcess): - @classmethod def define(cls, spec): super().define(spec) @@ -214,7 +203,6 @@ def test_nested_namespace_defaults(self): """Process with a default in a nested namespace should be created, even if top level namespace not supplied.""" class SomeProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -229,7 +217,6 @@ def test_raise_in_define(self): """Process which raises in its 'define' method. Check that the spec is not set.""" class BrokenProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -293,12 +280,11 @@ def test_run_kill(self): proc.execute() def test_get_description(self): - class ProcWithoutSpec(Process): pass class ProcWithSpec(Process): - """ Process with a spec and a docstring """ + """Process with a spec and a docstring""" @classmethod def define(cls, spec): @@ -324,9 +310,7 @@ def define(cls, spec): self.assertIsInstance(desc_with_spec['description'], str) def test_logging(self): - class LoggerTester(Process): - def run(self, **kwargs): self.logger.info('Test') @@ -438,7 +422,6 @@ async def async_test(): self.assertEqual(proc.state, ProcessState.FINISHED) def test_kill_in_run(self): - class KillProcess(Process): after_kill = False @@ -456,9 +439,7 @@ def run(self, **kwargs): self.assertEqual(proc.state, ProcessState.KILLED) def test_kill_when_paused_in_run(self): - class PauseProcess(Process): - def run(self, **kwargs): self.pause() self.kill() @@ -510,9 +491,7 @@ def test_run_multiple(self): self.assertDictEqual(proc_class.EXPECTED_OUTPUTS, result) def test_invalid_output(self): - class InvalidOutput(plumpy.Process): - def run(self): self.out('invalid', 5) @@ -536,7 +515,6 @@ def test_unsuccessful_result(self): ERROR_CODE = 256 class Proc(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -550,11 +528,10 @@ def run(self): self.assertEqual(proc.result(), ERROR_CODE) def test_pause_in_process(self): - """ Test that we can pause and cancel that by playing within the process """ + """Test that we can pause and cancel that by playing within the process""" test_case = self class TestPausePlay(plumpy.Process): - def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) @@ -574,12 +551,11 @@ def run(self): self.assertEqual(plumpy.ProcessState.FINISHED, proc.state) def test_pause_play_in_process(self): - """ Test that we can pause and play that by playing within the process """ + """Test that we can pause and play that by playing within the process""" test_case = self class TestPausePlay(plumpy.Process): - def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) @@ -596,7 +572,6 @@ def test_process_stack(self): test_case = self class StackTest(plumpy.Process): - def run(self): test_case.assertIs(self, Process.current()) @@ -613,7 +588,6 @@ def test_nested(process): expect_true.append(process == Process.current()) class StackTest(plumpy.Process): - def run(self): # TODO: unexpected behaviour here # if assert error happend here not raise @@ -623,7 +597,6 @@ def run(self): test_nested(self) class ParentProcess(plumpy.Process): - def run(self): expect_true.append(self == Process.current()) StackTest().execute() @@ -646,21 +619,17 @@ def test_process_nested(self): """ class StackTest(plumpy.Process): - def run(self): pass class ParentProcess(plumpy.Process): - def run(self): StackTest().execute() ParentProcess().execute() def test_call_soon(self): - class CallSoon(plumpy.Process): - def run(self): self.call_soon(self.do_except) @@ -680,7 +649,6 @@ def test_exception_during_on_entered(self): """Test that an exception raised during ``on_entered`` will cause the process to be excepted.""" class RaisingProcess(Process): - def on_entered(self, from_state): if from_state is not None and from_state.label == ProcessState.RUNNING: raise RuntimeError('exception during on_entered') @@ -696,9 +664,7 @@ def on_entered(self, from_state): assert str(process.exception()) == 'exception during on_entered' def test_exception_during_run(self): - class RaisingProcess(Process): - def run(self): raise RuntimeError('exception during run') @@ -862,7 +828,7 @@ async def async_test(): loop.run_until_complete(async_test()) def test_wait_save_continue(self): - """ Test that process saved while in WAITING state restarts correctly when loaded """ + """Test that process saved while in WAITING state restarts correctly when loaded""" loop = asyncio.get_event_loop() proc = utils.WaitForSignalProcess() @@ -905,7 +871,6 @@ def _check_round_trip(self, proc1): class TestProcessNamespace(unittest.TestCase): - def test_namespaced_process(self): """ Test that inputs in nested namespaces are properly validated and the returned @@ -913,7 +878,6 @@ def test_namespaced_process(self): """ class NameSpacedProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -938,7 +902,6 @@ def test_namespaced_process_inputs(self): """ class NameSpacedProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -964,7 +927,6 @@ def test_namespaced_process_dynamic(self): namespace = 'name.space' class DummyDynamicProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -991,14 +953,12 @@ def test_namespaced_process_outputs(self): namespace_nested = f'{namespace}.nested' class OutputMode(enum.Enum): - NONE = 0 DYNAMIC_PORT_NAMESPACE = 1 SINGLE_REQUIRED_PORT = 2 BOTH_SINGLE_AND_NAMESPACE = 3 class DummyDynamicProcess(Process): - @classmethod def define(cls, spec): super().define(spec) @@ -1057,7 +1017,6 @@ def run(self): class TestProcessEvents(unittest.TestCase): - def test_basic_events(self): proc = utils.DummyProcessWithOutput() events_tester = utils.ProcessListenerTester( @@ -1077,11 +1036,14 @@ def test_killed(self): def test_excepted(self): proc = utils.ExceptionProcess() - events_tester = utils.ProcessListenerTester(proc, ( - 'excepted', - 'running', - 'output_emitted', - )) + events_tester = utils.ProcessListenerTester( + proc, + ( + 'excepted', + 'running', + 'output_emitted', + ), + ) with self.assertRaises(RuntimeError): proc.execute() proc.result() @@ -1120,7 +1082,6 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id): class _RestartProcess(utils.WaitForSignalProcess): - @classmethod def define(cls, spec): super().define(spec) diff --git a/test/test_utils.py b/test/test_utils.py index 546261f2..c01d712b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -5,12 +5,10 @@ import warnings import pytest - from plumpy.utils import AttributesFrozendict, ensure_coroutine, load_function class TestAttributesFrozendict: - def test_getitem(self): d = AttributesFrozendict({'a': 5}) assert d['a'] == 5 @@ -40,7 +38,6 @@ async def async_fct(): class TestEnsureCoroutine: - def test_sync_func(self): coro = ensure_coroutine(fct) assert inspect.iscoroutinefunction(coro) @@ -50,9 +47,7 @@ def test_async_func(self): assert coro is async_fct def test_callable_class(self): - class AsyncDummy: - async def __call__(self): pass @@ -60,9 +55,7 @@ async def __call__(self): assert coro is AsyncDummy def test_callable_object(self): - class AsyncDummy: - async def __call__(self): pass diff --git a/test/test_waiting_process.py b/test/test_waiting_process.py index 87d39192..90427554 100644 --- a/test/test_waiting_process.py +++ b/test/test_waiting_process.py @@ -9,7 +9,6 @@ class TestWaitingProcess(unittest.TestCase): - def test_instance_state(self): proc = utils.ThreeSteps() wl = utils.ProcessSaver(proc) diff --git a/test/test_workchains.py b/test/test_workchains.py index c698aff9..9de0d1fe 100644 --- a/test/test_workchains.py +++ b/test/test_workchains.py @@ -3,9 +3,8 @@ import inspect import unittest -import pytest - import plumpy +import pytest from plumpy.process_listener import ProcessListener from plumpy.workchains import * @@ -33,9 +32,17 @@ def on_create(self): super().on_create() # Reset the finished step self.finished_steps = { - k: False for k in [ - self.s1.__name__, self.s2.__name__, self.s3.__name__, self.s4.__name__, self.s5.__name__, - self.s6.__name__, self.isA.__name__, self.isB.__name__, self.ltN.__name__ + k: False + for k in [ + self.s1.__name__, + self.s2.__name__, + self.s3.__name__, + self.s4.__name__, + self.s5.__name__, + self.s6.__name__, + self.isA.__name__, + self.isB.__name__, + self.ltN.__name__, ] } @@ -78,7 +85,6 @@ def _set_finished(self, function_name): class IfTest(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -101,7 +107,6 @@ def step2(self): class DummyWc(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -112,7 +117,6 @@ def do_nothing(self): class TestContext(unittest.TestCase): - def test_attributes(self): wc = DummyWc() wc.ctx.new_attr = 5 @@ -163,9 +167,7 @@ def test_run(self): self.assertTrue(finished, f'Step {step} was not called by workflow') def test_incorrect_outline(self): - class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -176,9 +178,7 @@ def define(cls, spec): Wf.spec() def test_same_input_node(self): - class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -199,7 +199,6 @@ def test_context(self): B = 'b' class ReturnA(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -209,7 +208,6 @@ def run(self): self.out('res', A) class ReturnB(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -219,7 +217,6 @@ def run(self): self.out('res', B) class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -288,13 +285,11 @@ def test_listener_persistence(self): process_finished_count = 0 class TestListener(plumpy.ProcessListener): - def on_process_finished(self, process, output): nonlocal process_finished_count process_finished_count += 1 class SimpleWorkChain(plumpy.WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -324,7 +319,6 @@ def step2(self): self.assertEqual(process_finished_count, 2) def test_return_in_outline(self): - class WcWithReturn(WorkChain): FAILED_CODE = 1 @@ -360,9 +354,7 @@ def default(self): workchain.execute() def test_return_in_step(self): - class WcWithReturn(WorkChain): - FAILED_CODE = 1 @classmethod @@ -393,9 +385,7 @@ def after(self): workchain.execute() def test_tocontext_schedule_workchain(self): - class MainWorkChain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -409,7 +399,6 @@ def check(self): assert self.ctx.subwc.out.value == 5 class SubWorkChain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -453,7 +442,6 @@ def test_to_context(self): val = 5 class SimpleWc(plumpy.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -463,7 +451,6 @@ def run(self): self.out('_return', val) class Workchain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -484,7 +471,6 @@ def test_output_namespace(self): """Test running a workchain with nested outputs.""" class TestWorkChain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -501,7 +487,6 @@ def test_exception_tocontext(self): my_exception = RuntimeError('Should not be reached') class Workchain(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -528,7 +513,6 @@ def test_stepper_info(self): """Check status information provided by steppers""" class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -539,7 +523,13 @@ def define(cls, spec): cls.chill, cls.chill, ), - if_(cls.do_step)(cls.chill,).elif_(cls.do_step)(cls.chill,).else_(cls.chill), + if_(cls.do_step)( + cls.chill, + ) + .elif_(cls.do_step)( + cls.chill, + ) + .else_(cls.chill), ) def check_n(self): @@ -560,7 +550,6 @@ def do_step(self): return False class StatusCollector(ProcessListener): - def __init__(self): self.stepper_strings = [] @@ -574,9 +563,15 @@ def on_process_running(self, process): wf.execute() stepper_strings = [ - '0:check_n', '1:while_(do_step)', '1:while_(do_step)(1:chill)', '1:while_(do_step)', - '1:while_(do_step)(1:chill)', '1:while_(do_step)', '1:while_(do_step)(1:chill)', '1:while_(do_step)', - '2:if_(do_step)' + '0:check_n', + '1:while_(do_step)', + '1:while_(do_step)(1:chill)', + '1:while_(do_step)', + '1:while_(do_step)(1:chill)', + '1:while_(do_step)', + '1:while_(do_step)(1:chill)', + '1:while_(do_step)', + '2:if_(do_step)', ] self.assertListEqual(collector.stepper_strings, stepper_strings) @@ -593,7 +588,6 @@ def test_immutable_input(self): test_class = self class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) @@ -630,7 +624,6 @@ def test_immutable_input_namespace(self): test_class = self class Wf(WorkChain): - @classmethod def define(cls, spec): super().define(spec) diff --git a/test/utils.py b/test/utils.py index 1f7408f6..22bce906 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,12 +1,10 @@ # -*- coding: utf-8 -*- """Utilities for tests""" + import asyncio import collections -from collections.abc import Mapping import unittest - -import kiwipy.rmq -import shortuuid +from collections.abc import Mapping import plumpy from plumpy import persistence, process_states, processes, utils @@ -24,7 +22,9 @@ class DummyProcess(processes.Process): """ EXPECTED_STATE_SEQUENCE = [ - process_states.ProcessState.CREATED, process_states.ProcessState.RUNNING, process_states.ProcessState.FINISHED + process_states.ProcessState.CREATED, + process_states.ProcessState.RUNNING, + process_states.ProcessState.FINISHED, ] EXPECTED_OUTPUTS = {} @@ -58,14 +58,12 @@ def run(self, **kwargs): class KeyboardInterruptProc(processes.Process): - @utils.override def run(self): raise KeyboardInterrupt() class ProcessWithCheckpoint(processes.Process): - @utils.override def run(self): return process_states.Continue(self.last_step) @@ -75,7 +73,6 @@ def last_step(self): class WaitForSignalProcess(processes.Process): - @utils.override def run(self): return process_states.Wait(self.last_step) @@ -85,14 +82,13 @@ def last_step(self): class KillProcess(processes.Process): - @utils.override def run(self): return process_states.Kill('killed') class MissingOutputProcess(processes.Process): - """ A process that does not generate a required output """ + """A process that does not generate a required output""" @classmethod def define(cls, spec): @@ -101,7 +97,6 @@ def define(cls, spec): class NewLoopProcess(processes.Process): - def __init__(self, *args, **kwargs): kwargs['loop'] = plumpy.new_event_loop() super().__init__(*args, **kwargs) @@ -118,8 +113,7 @@ def called(cls, event): cls.called_events.append(event) def __init__(self, *args, **kwargs): - assert isinstance(self, processes.Process), \ - 'Mixin has to be used with a type derived from a Process' + assert isinstance(self, processes.Process), 'Mixin has to be used with a type derived from a Process' super().__init__(*args, **kwargs) self.__class__.called_events = [] @@ -165,7 +159,6 @@ def on_terminate(self): class ProcessEventsTester(EventsTesterMixin, processes.Process): - @classmethod def define(cls, spec): super().define(spec) @@ -193,7 +186,6 @@ def last_step(self): class TwoCheckpointNoFinish(ProcessEventsTester): - def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) @@ -203,21 +195,18 @@ def middle_step(self): class ExceptionProcess(ProcessEventsTester): - def run(self): self.out('test', 5) raise RuntimeError('Great scott!') class ThreeStepsThenException(ThreeSteps): - @utils.override def last_step(self): raise RuntimeError('Great scott!') class ProcessListenerTester(plumpy.ProcessListener): - def __init__(self, process, expected_events): process.add_process_listener(self) self.expected_events = set(expected_events) @@ -249,7 +238,6 @@ def on_process_killed(self, process, msg): class Saver: - def __init__(self): self.snapshots = [] self.outputs = [] @@ -357,7 +345,11 @@ def on_process_killed(self, process, msg): TEST_PROCESSES = [DummyProcess, DummyProcessWithOutput, DummyProcessWithDynamicOutput, ThreeSteps] TEST_WAITING_PROCESSES = [ - ProcessWithCheckpoint, TwoCheckpointNoFinish, ExceptionProcess, ProcessEventsTester, ThreeStepsThenException + ProcessWithCheckpoint, + TwoCheckpointNoFinish, + ExceptionProcess, + ProcessEventsTester, + ThreeStepsThenException, ] TEST_EXCEPTION_PROCESSES = [ExceptionProcess, ThreeStepsThenException, MissingOutputProcess] @@ -402,7 +394,7 @@ def check_process_against_snapshots(loop, proc_class, snapshots): saver.snapshots[-j], snapshots[-j], saver.snapshots[-j], - exclude={'exception', '_listeners'} + exclude={'exception', '_listeners'}, ) j += 1 @@ -438,9 +430,8 @@ def compare_value(bundle1, bundle2, v1, v2, exclude=None): compare_value(bundle1, bundle2, list(v1), list(v2), exclude) elif isinstance(v1, set) and isinstance(v2, set): raise NotImplementedError('Comparison between sets not implemented') - else: - if v1 != v2: - raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}') + elif v1 != v2: + raise ValueError(f'Dict values mismatch for :\n{v1} != {v2}') class TestPersister(persistence.Persister): @@ -449,7 +440,7 @@ class TestPersister(persistence.Persister): """ def save_checkpoint(self, process, tag=None): - """ Create the checkpoint bundle """ + """Create the checkpoint bundle""" persistence.Bundle(process) def load_checkpoint(self, pid, tag=None): @@ -469,7 +460,7 @@ def delete_process_checkpoints(self, pid): def run_until_waiting(proc): - """ Set up a future that will be resolved on entering the WAITING state """ + """Set up a future that will be resolved on entering the WAITING state""" from plumpy import ProcessState listener = plumpy.ProcessListener() @@ -490,7 +481,7 @@ def on_waiting(_waiting_proc): def run_until_paused(proc): - """ Set up a future that will be resolved when the process is paused """ + """Set up a future that will be resolved when the process is paused""" listener = plumpy.ProcessListener() paused = plumpy.Future()