From d72fcd24b04c87311e50b6f6ca6b436a095fdedc Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 27 Nov 2024 15:49:58 +0100 Subject: [PATCH] Address mypy complains --- src/plumpy/communications.py | 2 +- src/plumpy/lang.py | 4 ++-- src/plumpy/workchains.py | 22 +++++++++++++++++++--- test/test_workchains.py | 2 ++ 4 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/plumpy/communications.py b/src/plumpy/communications.py index f4941d0f..01bac433 100644 --- a/src/plumpy/communications.py +++ b/src/plumpy/communications.py @@ -81,7 +81,7 @@ def convert_to_comm( def _passthrough(*args: Any, **kwargs: Any) -> bool: sender = kwargs.get('sender', args[1]) subject = kwargs.get('subject', args[2]) - return callback.is_filtered(sender, subject) # type: ignore[attr-defined] + return callback.is_filtered(sender, subject) else: def _passthrough(*args: Any, **kwargs: Any) -> bool: # pylint: disable=unused-argument diff --git a/src/plumpy/lang.py b/src/plumpy/lang.py index aad2abb6..450927d6 100644 --- a/src/plumpy/lang.py +++ b/src/plumpy/lang.py @@ -31,7 +31,7 @@ def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Callable[..., Any]: return func(self, *args, **kwargs) else: - wrapped_fn = func + wrapped_fn = func # type: ignore[assignment] return wrapped_fn @@ -60,7 +60,7 @@ def wrapped_fn(self: Any, *args: Any, **kwargs: Any) -> Callable[..., Any]: return func(self, *args, **kwargs) else: - wrapped_fn = func + wrapped_fn = func # type: ignore[assignment] return wrapped_fn diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 60208c09..b7690c82 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -1,11 +1,27 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import abc import asyncio import collections import inspect import logging import re -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast +from typing import ( + Any, + Callable, + Dict, + Hashable, + List, + Mapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) import kiwipy @@ -327,7 +343,7 @@ class _Block(_Instruction, collections.abc.Sequence): def __init__(self, instructions: Sequence[Union[_Instruction, WC_COMMAND_TYPE]]) -> None: # Build up the list of commands - comms = [] + comms: MutableSequence[_Instruction | _FunctionCall] = [] for instruction in instructions: if not isinstance(instruction, _Instruction): # Assume it's a function call @@ -335,7 +351,7 @@ def __init__(self, instructions: Sequence[Union[_Instruction, WC_COMMAND_TYPE]]) else: comms.append(instruction) - self._instruction: List[Union[_Instruction, _FunctionCall]] = comms + self._instruction: MutableSequence[_Instruction | _FunctionCall] = comms def __getitem__(self, index: int) -> Union[_Instruction, _FunctionCall]: # type: ignore return self._instruction[index] diff --git a/test/test_workchains.py b/test/test_workchains.py index 795a4997..1335517f 100644 --- a/test/test_workchains.py +++ b/test/test_workchains.py @@ -311,6 +311,8 @@ def step2(self): workchain = SimpleWorkChain() workchain.add_process_listener(TestListener()) + workchain.execute() + self.assertEqual(process_finished_count, 1) workchain_checkpoint = persister.load_checkpoint(workchain.pid, 'step1').unbundle()