From 4ef293a297ed6c7d3709cf165095a24d335fb31d Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 1 Sep 2023 11:54:45 +0200 Subject: [PATCH] Devops: Update `pyproject.toml` configuration (#6085) Added stricter rules for `mypy` and `pytest`. Suggestions taken after automated analysis by the following tool: https://learn.scientific-python.org/development/guides/repo-review/ --- aiida/cmdline/commands/cmd_config.py | 10 ++-- aiida/cmdline/groups/verdi.py | 2 +- aiida/common/lang.py | 2 +- aiida/common/progress_reporter.py | 2 +- aiida/engine/daemon/client.py | 3 +- aiida/engine/launch.py | 3 +- aiida/engine/processes/builder.py | 2 +- aiida/engine/processes/calcjobs/calcjob.py | 10 ++-- aiida/engine/processes/functions.py | 20 ++++--- aiida/engine/processes/process.py | 19 +++---- aiida/engine/processes/utils.py | 6 ++- aiida/engine/processes/workchains/restart.py | 3 +- .../engine/processes/workchains/workchain.py | 4 +- aiida/engine/runners.py | 4 +- aiida/manage/caching.py | 2 +- aiida/orm/autogroup.py | 19 +++---- aiida/orm/entities.py | 10 ++-- aiida/orm/groups.py | 6 +-- aiida/orm/implementation/entities.py | 4 +- aiida/orm/nodes/data/code/abstract.py | 2 +- aiida/orm/nodes/data/code/installed.py | 2 +- aiida/orm/nodes/data/enum.py | 2 +- aiida/orm/nodes/links.py | 2 +- aiida/orm/nodes/node.py | 4 +- .../nodes/process/calculation/calcfunction.py | 2 +- aiida/orm/nodes/process/workflow/workchain.py | 2 +- .../nodes/process/workflow/workfunction.py | 2 +- aiida/orm/nodes/repository.py | 8 +-- aiida/orm/querybuilder.py | 18 +++---- aiida/orm/utils/links.py | 2 +- aiida/parsers/parser.py | 6 +-- .../plugins/templatereplacer/parser.py | 2 +- aiida/repository/backend/abstract.py | 7 ++- aiida/schedulers/datastructures.py | 6 +-- aiida/storage/psql_dos/alembic_cli.py | 7 ++- aiida/storage/psql_dos/backend.py | 8 +-- .../migrations/utils/create_dbattribute.py | 2 +- aiida/tools/archive/create.py | 52 ++++++++++++++----- .../implementations/sqlite_zip/writer.py | 2 +- aiida/tools/archive/imports.py | 16 +++--- aiida/tools/visualization/graph.py | 7 ++- pyproject.toml | 9 +++- tests/cmdline/commands/test_calcjob.py | 2 +- tests/cmdline/commands/test_data.py | 2 +- tests/cmdline/commands/test_process.py | 2 +- tests/engine/test_process_function.py | 2 +- tests/plugins/test_entry_point.py | 8 ++- ...st_0037_attributes_extras_settings_json.py | 4 +- ..._data_migration_legacy_job_calculations.py | 2 +- tests/test_conftest.py | 2 +- tests/test_dbimporters.py | 2 +- .../archive/migration/test_v04_to_v05.py | 2 +- .../archive/migration/test_v05_to_v06.py | 2 +- 53 files changed, 194 insertions(+), 137 deletions(-) diff --git a/aiida/cmdline/commands/cmd_config.py b/aiida/cmdline/commands/cmd_config.py index 6e010e4625..80ce61a15a 100644 --- a/aiida/cmdline/commands/cmd_config.py +++ b/aiida/cmdline/commands/cmd_config.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """`verdi config` command.""" +from __future__ import annotations + import json from pathlib import Path import textwrap @@ -40,7 +42,7 @@ def verdi_config_list(ctx, prefix, description: bool): from aiida.manage.configuration import Config, Profile config: Config = ctx.obj.config - profile: Profile = ctx.obj.get('profile', None) + profile: Profile | None = ctx.obj.get('profile', None) if not profile: echo.echo_warning('no profiles configured: run `verdi setup` to create one') @@ -78,7 +80,7 @@ def verdi_config_show(ctx, option): 
from aiida.manage.configuration.options import NO_DEFAULT config: Config = ctx.obj.config - profile: Profile = ctx.obj.profile + profile: Profile | None = ctx.obj.profile dct = { 'schema': option.schema, @@ -124,7 +126,7 @@ def verdi_config_set(ctx, option, value, globally, append, remove): echo.echo_critical('Cannot flag both append and remove') config: Config = ctx.obj.config - profile: Profile = ctx.obj.profile + profile: Profile | None = ctx.obj.profile if option.global_only: globally = True @@ -164,7 +166,7 @@ def verdi_config_unset(ctx, option, globally): from aiida.manage.configuration import Config, Profile config: Config = ctx.obj.config - profile: Profile = ctx.obj.profile + profile: Profile | None = ctx.obj.profile if option.global_only: globally = True diff --git a/aiida/cmdline/groups/verdi.py b/aiida/cmdline/groups/verdi.py index 64a08ce8c8..d0bd9ed49f 100644 --- a/aiida/cmdline/groups/verdi.py +++ b/aiida/cmdline/groups/verdi.py @@ -62,7 +62,7 @@ class VerdiCommandGroup(click.Group): def add_verbosity_option(cmd: click.Command): """Apply the ``verbosity`` option to the command, which is common to all ``verdi`` commands.""" # Only apply the option if it hasn't been already added in a previous call. - if cmd is not None and 'verbosity' not in [param.name for param in cmd.params]: + if 'verbosity' not in [param.name for param in cmd.params]: cmd = options.VERBOSITY()(cmd) return cmd diff --git a/aiida/common/lang.py b/aiida/common/lang.py index 3df6bac4b0..d8598eafc3 100644 --- a/aiida/common/lang.py +++ b/aiida/common/lang.py @@ -72,7 +72,7 @@ def wrapped_fn(self, *args, **kwargs): # pylint: disable=missing-docstring else: wrapped_fn = func - return wrapped_fn # type: ignore + return wrapped_fn # type: ignore[return-value] return wrap diff --git a/aiida/common/progress_reporter.py b/aiida/common/progress_reporter.py index 2c8166d0d4..9633946d58 100644 --- a/aiida/common/progress_reporter.py +++ b/aiida/common/progress_reporter.py @@ -156,7 +156,7 @@ def set_progress_reporter(reporter: Optional[Type[ProgressReporterAbstract]] = N if reporter is None: PROGRESS_REPORTER = ProgressReporterNull elif kwargs: - PROGRESS_REPORTER = partial(reporter, **kwargs) # type: ignore + PROGRESS_REPORTER = partial(reporter, **kwargs) # type: ignore[assignment] else: PROGRESS_REPORTER = reporter diff --git a/aiida/engine/daemon/client.py b/aiida/engine/daemon/client.py index d5f804eee4..7274bc2fb4 100644 --- a/aiida/engine/daemon/client.py +++ b/aiida/engine/daemon/client.py @@ -713,8 +713,7 @@ def _start_daemon(self, number_workers: int = 1, foreground: bool = False) -> No pidfile.create(os.getpid()) # Configure the logger - loggerconfig = None - loggerconfig = loggerconfig or arbiter.loggerconfig or None + loggerconfig = arbiter.loggerconfig or None configure_logger(circus_logger, loglevel, logoutput, loggerconfig) # Main loop diff --git a/aiida/engine/launch.py b/aiida/engine/launch.py index 888536cd61..79e2aff066 100644 --- a/aiida/engine/launch.py +++ b/aiida/engine/launch.py @@ -17,6 +17,7 @@ from .processes.builder import ProcessBuilder from .processes.functions import FunctionProcess from .processes.process import Process +from .runners import ResultAndPk from .utils import instantiate_process, is_process_scoped # pylint: disable=no-name-in-module __all__ = ('run', 'run_get_pk', 'run_get_node', 'submit') @@ -60,7 +61,7 @@ def run_get_node(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[ return runner.run_get_node(process, *args, **inputs) -def run_get_pk(process: 
TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], int]: +def run_get_pk(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> ResultAndPk: """Run the process with the supplied inputs in a local runner that will block until the process is completed. :param process: the process class, instance, builder or function to run diff --git a/aiida/engine/processes/builder.py b/aiida/engine/processes/builder.py index 667d6fa382..8a50b1426f 100644 --- a/aiida/engine/processes/builder.py +++ b/aiida/engine/processes/builder.py @@ -79,7 +79,7 @@ def fgetter(self, name=name): return self._data.get(name) elif port.has_default(): - def fgetter(self, name=name, default=port.default): # type: ignore # pylint: disable=cell-var-from-loop + def fgetter(self, name=name, default=port.default): # type: ignore[misc] # pylint: disable=cell-var-from-loop return self._data.get(name, default) else: diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index 475fa94e4d..866d41cc26 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -517,7 +517,7 @@ def get_state_classes(cls) -> Dict[Hashable, Type[plumpy.process_states.State]]: @property def node(self) -> orm.CalcJobNode: - return super().node # type: ignore + return super().node # type: ignore[return-value] @override def on_terminated(self) -> None: @@ -616,7 +616,7 @@ def _perform_dry_run(self): calc_info = self.presubmit(folder) transport.chdir(folder.abspath) upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) - self.node.dry_run_info = { # type: ignore + self.node.dry_run_info = { # type: ignore[attr-defined] 'folder': folder.abspath, 'script_filename': self.node.get_option('submit_script_filename') } @@ -768,7 +768,7 @@ def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]: return None if exit_code is not None and not isinstance(exit_code, ExitCode): - args = (scheduler.__class__.__name__, type(exit_code)) + args = (scheduler.__class__.__name__, type(exit_code)) # type: ignore[unreachable] raise ValueError('`{}.parse_output` returned neither an `ExitCode` nor None, but: {}'.format(*args)) return exit_code @@ -797,7 +797,7 @@ def parse_retrieved_output(self, retrieved_temporary_folder: Optional[str] = Non break if exit_code is not None and not isinstance(exit_code, ExitCode): - args = (parser_class.__name__, type(exit_code)) + args = (parser_class.__name__, type(exit_code)) # type: ignore[unreachable] raise ValueError('`{}.parse` returned neither an `ExitCode` nor None, but: {}'.format(*args)) return exit_code @@ -894,7 +894,7 @@ def presubmit(self, folder: Folder) -> CalcInfo: # Set resources, also with get_default_mpiprocs_per_machine resources = self.node.get_option('resources') scheduler.preprocess_resources(resources or {}, computer.get_default_mpiprocs_per_machine()) - job_tmpl.job_resource = scheduler.create_job_resource(**resources) # type: ignore + job_tmpl.job_resource = scheduler.create_job_resource(**resources) # type: ignore[arg-type] subst_dict = {'tot_num_mpiprocs': job_tmpl.job_resource.get_tot_num_mpiprocs()} diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index 8baf92c903..f932f09569 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -81,7 +81,7 @@ def get_stack_size(size: int = 2) -> int: # type: ignore[return] for size in itertools.count(size, 8): # pylint: 
disable=redefined-argument-from-local frame = frame.f_back.f_back.f_back.f_back.f_back.f_back.f_back.f_back # type: ignore[assignment,union-attr] except AttributeError: - while frame: + while frame: # type: ignore[truthy-bool] frame = frame.f_back # type: ignore[assignment] size += 1 return size - 1 @@ -234,6 +234,7 @@ def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]: """ result, node = run_get_node(*args, **kwargs) + assert node.pk is not None return result, node.pk @functools.wraps(function) @@ -323,10 +324,13 @@ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['Func """ # pylint: disable=too-many-statements - if not issubclass(node_class, ProcessNode) or not issubclass(node_class, FunctionCalculationMixin): + if ( + not issubclass(node_class, ProcessNode) or # type: ignore[redundant-expr] + not issubclass(node_class, FunctionCalculationMixin) # type: ignore[unreachable] + ): raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`') - signature = inspect.signature(func) + signature = inspect.signature(func) # type: ignore[unreachable] args: list[str] = [] varargs: str | None = None @@ -519,7 +523,7 @@ def get_or_create_db_record(cls) -> 'ProcessNode': def __init__(self, *args, **kwargs) -> None: if kwargs.get('enable_persistence', False): raise RuntimeError('Cannot persist a function process') - super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore + super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore[misc] @property def process_class(self) -> t.Callable[..., t.Any]: @@ -586,11 +590,11 @@ def run(self) -> 'ExitCode' | None: result = self._func(*args, **kwargs) - if result is None or isinstance(result, ExitCode): - return result + if result is None or isinstance(result, ExitCode): # type: ignore[redundant-expr] + return result # type: ignore[unreachable] - if isinstance(result, Data): - self.out(self.SINGLE_OUTPUT_LINKNAME, result) + if isinstance(result, Data): # type: ignore[unreachable] + self.out(self.SINGLE_OUTPUT_LINKNAME, result) # type: ignore[unreachable] elif isinstance(result, collections.abc.Mapping): for name, value in result.items(): self.out(name, value) diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index eebfc3bc52..9ec34bf6e7 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -251,7 +251,6 @@ def metadata(self) -> AttributeDict: """ try: - assert self.inputs is not None return self.inputs.metadata except (AssertionError, AttributeError): return AttributeDict() @@ -297,7 +296,6 @@ def get_provenance_inputs_iterator(self) -> Iterator[Tuple[str, Union[InputPort, :rtype: filter """ - assert self.inputs is not None return filter(lambda kv: not kv[0].startswith('_'), self.inputs.items()) @override @@ -321,7 +319,7 @@ def load_instance_state( super().load_instance_state(saved_state, load_context) if self.SaveKeys.CALC_ID.value in saved_state: - self._node = orm.load_node(saved_state[self.SaveKeys.CALC_ID.value]) # type: ignore + self._node = orm.load_node(saved_state[self.SaveKeys.CALC_ID.value]) # type: ignore[assignment] self._pid = self.node.pk # pylint: disable=attribute-defined-outside-init else: self._pid = self._create_and_setup_db_record() # pylint: disable=attribute-defined-outside-init @@ -429,7 +427,7 @@ def on_entered(self, from_state: Optional[plumpy.process_states.State]) -> None: except ValueError: # pylint: disable=try-except-raise raise finally: - 
self.node.set_process_state(self._state.LABEL) # type: ignore + self.node.set_process_state(self._state.LABEL) # type: ignore[arg-type] self._save_checkpoint() set_process_state_change_timestamp(self) @@ -464,7 +462,7 @@ def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: self.report(''.join(traceback.format_exception(*exc_info))) @override - def on_finish(self, result: Union[int, ExitCode], successful: bool) -> None: + def on_finish(self, result: Union[int, ExitCode, None], successful: bool) -> None: """ Set the finish status on the process node. :param result: result of the process @@ -559,7 +557,7 @@ def get_parent_calc(self) -> Optional[orm.ProcessNode]: if self._parent_pid is None: return None - return orm.load_node(pk=self._parent_pid) # type: ignore + return orm.load_node(pk=self._parent_pid) # type: ignore[return-value] @classmethod def build_process_type(cls) -> str: @@ -702,7 +700,6 @@ def _setup_db_record(self) -> None: In addition, the parent calculation will be setup with a CALL link if applicable and all inputs will be linked up as well. """ - assert self.inputs is not None assert not self.node.is_sealed, 'process node cannot be sealed when setting up the database record' # Store important process attributes in the node proxy @@ -731,9 +728,6 @@ def _setup_version_info(self) -> None: """Store relevant plugin version information.""" from aiida.plugins.entry_point import format_entry_point_string - if self.inputs is None: - return - version_info = self.runner.plugin_version_provider.get_version_info(self.__class__) for key, monitor in self.inputs.get('monitors', {}).items(): @@ -836,7 +830,6 @@ def _flat_inputs(self) -> Dict[str, Any]: :return: flat dictionary of parsed inputs """ - assert self.inputs is not None inputs = {key: value for key, value in self.inputs.items() if key != self.spec().metadata_key} return dict(self._flatten_inputs(self.spec().inputs, inputs)) @@ -890,7 +883,9 @@ def _flatten_inputs( items.extend(sub_items) return items - assert (port is None) or (isinstance(port, InputPort) and (port.is_metadata or port.non_db)) + assert (port is None) or ( + isinstance(port, InputPort) and (port.is_metadata or port.non_db) # type: ignore[redundant-expr] + ) return [] def _flatten_outputs( diff --git a/aiida/engine/processes/utils.py b/aiida/engine/processes/utils.py index 340131a78b..44c74728e0 100644 --- a/aiida/engine/processes/utils.py +++ b/aiida/engine/processes/utils.py @@ -14,12 +14,14 @@ def prune_mapping(value): :param value: A nested mapping of port values. :return: The same mapping but without any nested namespace that is completely empty. """ - if isinstance(value, Mapping) and not isinstance(value, Node): + if isinstance(value, Mapping) and not isinstance(value, Node): # type: ignore[unreachable] result = {} for key, sub_value in value.items(): pruned = prune_mapping(sub_value) # If `pruned` is an "empty'ish" mapping and not an instance of `Node`, skip it, otherwise keep it. 
- if not (isinstance(pruned, Mapping) and not pruned and not isinstance(pruned, Node)): + if not ( + isinstance(pruned, Mapping) and not pruned and not isinstance(pruned, Node) # type: ignore[unreachable] + ): result[key] = pruned return result diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 2b4c544d18..8d299b8a6b 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -427,7 +427,8 @@ def _wrap_bare_dict_inputs(self, port_namespace: 'PortNamespace', inputs: Dict[s continue port = port_namespace[key] - valid_types = port.valid_type if isinstance(port.valid_type, (list, tuple)) else (port.valid_type,) + valid_types = port.valid_type \ + if isinstance(port.valid_type, (list, tuple)) else (port.valid_type,) # type: ignore[redundant-expr] if isinstance(port, PortNamespace): wrapped[key] = self._wrap_bare_dict_inputs(port, value) diff --git a/aiida/engine/processes/workchains/workchain.py b/aiida/engine/processes/workchains/workchain.py index e6ca21a4b4..e17f816c37 100644 --- a/aiida/engine/processes/workchains/workchain.py +++ b/aiida/engine/processes/workchains/workchain.py @@ -139,7 +139,7 @@ def spec(cls) -> WorkChainSpec: @property def node(self) -> WorkChainNode: - return super().node # type: ignore + return super().node # type: ignore[return-value] @property def ctx(self) -> AttributeDict: @@ -408,7 +408,7 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: if awaitable.outputs: value = {entry.link_label: entry.node for entry in node.base.links.get_outgoing()} else: - value = node # type: ignore + value = node # type: ignore[assignment] self._resolve_awaitable(awaitable, value) diff --git a/aiida/engine/runners.py b/aiida/engine/runners.py index 3544bc90c0..aca4a8b8da 100644 --- a/aiida/engine/runners.py +++ b/aiida/engine/runners.py @@ -9,6 +9,8 @@ ########################################################################### # pylint: disable=global-statement """Runners that can run and submit processes.""" +from __future__ import annotations + import asyncio import functools import logging @@ -43,7 +45,7 @@ class ResultAndNode(NamedTuple): class ResultAndPk(NamedTuple): result: Dict[str, Any] - pk: int + pk: int | None TYPE_RUN_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name diff --git a/aiida/manage/caching.py b/aiida/manage/caching.py index 2dfea4f9f4..3c78b119b7 100644 --- a/aiida/manage/caching.py +++ b/aiida/manage/caching.py @@ -40,7 +40,7 @@ def __init__(self): def clear(self): """Clear caching overrides.""" - self.__init__() # type: ignore + self.__init__() # type: ignore[misc] def enable_all(self): self._default_all = 'enable' diff --git a/aiida/orm/autogroup.py b/aiida/orm/autogroup.py index e685451b6d..2aadcfb72a 100644 --- a/aiida/orm/autogroup.py +++ b/aiida/orm/autogroup.py @@ -8,8 +8,9 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module to manage the autogrouping functionality by ``verdi run``.""" +from __future__ import annotations + import re -from typing import List, Optional from aiida.common import exceptions, timezone from aiida.common.escaping import escape_for_sql_like, get_regex_pattern_from_sql @@ -44,8 +45,8 @@ def __init__(self, backend): self._backend = backend self._enabled = False - self._exclude: Optional[List[str]] = None - self._include: Optional[List[str]] = None + self._exclude: 
list[str] | None = None + self._include: list[str] | None = None self._group_label_prefix = f"Verdi autogroup on {timezone.now().strftime('%Y-%m-%d %H:%M:%S')}" self._group_label = None # Actual group label, set by `get_or_create_group` @@ -63,13 +64,13 @@ def disable(self) -> None: """Disable the auto-grouping.""" self._enabled = False - def get_exclude(self) -> Optional[List[str]]: + def get_exclude(self) -> list[str] | None: """Return the list of classes to exclude from autogrouping. Returns ``None`` if no exclusion list has been set.""" return self._exclude - def get_include(self) -> Optional[List[str]]: + def get_include(self) -> list[str] | None: """Return the list of classes to include in the autogrouping. Returns ``None`` if no inclusion list has been set.""" @@ -81,7 +82,7 @@ def get_group_label_prefix(self) -> str: return self._group_label_prefix @staticmethod - def validate(strings: Optional[List[str]]): + def validate(strings: list[str] | None): """Validate the list of strings passed to set_include and set_exclude.""" if strings is None: return @@ -97,7 +98,7 @@ def validate(strings: Optional[List[str]]): f"'{string}' has an invalid prefix, must be among: {sorted(valid_prefixes)}" ) - def set_exclude(self, exclude: Optional[List[str]]) -> None: + def set_exclude(self, exclude: list[str] | str | None) -> None: """Set the list of classes to exclude in the autogrouping. :param exclude: a list of valid entry point strings (might contain '%' to be used as @@ -112,7 +113,7 @@ def set_exclude(self, exclude: Optional[List[str]]) -> None: raise exceptions.ValidationError('Cannot both specify exclude and include') self._exclude = exclude - def set_include(self, include: Optional[List[str]]) -> None: + def set_include(self, include: list[str] | str | None) -> None: """Set the list of classes to include in the autogrouping. 
:param include: a list of valid entry point strings (might contain '%' to be used as @@ -127,7 +128,7 @@ def set_include(self, include: Optional[List[str]]) -> None: raise exceptions.ValidationError('Cannot both specify exclude and include') self._include = include - def set_group_label_prefix(self, label_prefix: Optional[str]) -> None: + def set_group_label_prefix(self, label_prefix: str | None) -> None: """Set the label of the group to be created (or use a default).""" if label_prefix is None: label_prefix = f"Verdi autogroup on {timezone.now().strftime('%Y-%m-%d %H:%M:%S')}" diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index b15feaba3e..cec8122d88 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module for all common top level AiiDA entity classes and methods""" +from __future__ import annotations + import abc from enum import Enum from functools import lru_cache @@ -79,7 +81,7 @@ def __call__(self: CollectionType, backend: 'StorageBackend') -> CollectionType: """Get or create a cached collection using a new backend.""" if backend is self._backend: return self - return self.get_cached(self.entity_type, backend=backend) # type: ignore + return self.get_cached(self.entity_type, backend=backend) # type: ignore[arg-type] @property def entity_type(self) -> Type[EntityType]: @@ -162,7 +164,7 @@ def count(self, filters: Optional['FilterType'] = None) -> int: class Entity(abc.ABC, Generic[BackendEntityType, CollectionType]): """An AiiDA entity""" - _CLS_COLLECTION: Type[CollectionType] = Collection # type: ignore + _CLS_COLLECTION: Type[CollectionType] = Collection # type: ignore[assignment] @classproperty def objects(cls: EntityType) -> CollectionType: # pylint: disable=no-self-argument @@ -216,7 +218,7 @@ def initialize(self) -> None: """ @property - def id(self) -> int: # pylint: disable=invalid-name + def id(self) -> int | None: # pylint: disable=invalid-name """Return the id for this entity. This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. @@ -229,7 +231,7 @@ def id(self) -> int: # pylint: disable=invalid-name return self._backend_entity.id @property - def pk(self) -> int: + def pk(self) -> int | None: """Return the primary key for this entity. This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. 
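A note on the `int | None` return types introduced for `Entity.id` and `Entity.pk` above: an unstored entity has no database primary key yet, which is why this patch also adds `assert node.pk is not None` guards at call sites (e.g. in `run_get_pk` and `tools/visualization/graph.py`). A minimal sketch of the pattern, assuming a configured AiiDA profile is already loaded:

```python
from aiida import orm

node = orm.Int(5)           # unstored node: no primary key assigned yet
assert node.pk is None      # ``pk`` is Optional until the node is stored

node.store()                # storing assigns the database primary key
assert node.pk is not None  # the guard the patch adds before using ``pk``
print(f'stored with pk={node.pk}')
```
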
diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index f65c9d9394..83adb822c8 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from aiida.orm import Node, User from aiida.orm.implementation import BackendGroup, StorageBackend - from aiida.plugins.entry_point import EntryPoint # type: ignore + from aiida.plugins.entry_point import EntryPoint # type: ignore[attr-defined] __all__ = ('Group', 'AutoGroup', 'ImportGroup', 'UpfFamily') @@ -305,7 +305,7 @@ def add_nodes(self, nodes: Union['Node', Sequence['Node']]) -> None: # Cannot use `collections.Iterable` here, because that would also match iterable `Node` sub classes like `List` if not isinstance(nodes, (list, tuple)): - nodes = [nodes] # type: ignore + nodes = [nodes] # type: ignore[list-item] for node in nodes: type_check(node, Node) @@ -326,7 +326,7 @@ def remove_nodes(self, nodes: Union['Node', Sequence['Node']]) -> None: # Cannot use `collections.Iterable` here, because that would also match iterable `Node` sub classes like `List` if not isinstance(nodes, (list, tuple)): - nodes = [nodes] # type: ignore + nodes = [nodes] # type: ignore[list-item] for node in nodes: type_check(node, Node) diff --git a/aiida/orm/implementation/entities.py b/aiida/orm/implementation/entities.py index 41f8e8b988..52320777b3 100644 --- a/aiida/orm/implementation/entities.py +++ b/aiida/orm/implementation/entities.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Classes and methods for backend non-specific entities""" +from __future__ import annotations + import abc from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterable, List, Tuple, Type, TypeVar @@ -44,7 +46,7 @@ def id(self) -> int: # pylint: disable=invalid-name """ @property - def pk(self) -> int: + def pk(self) -> int | None: """Return the id for this entity. This is unique only amongst entities of this type for a particular backend. diff --git a/aiida/orm/nodes/data/code/abstract.py b/aiida/orm/nodes/data/code/abstract.py index afd53ebfd1..721049879f 100644 --- a/aiida/orm/nodes/data/code/abstract.py +++ b/aiida/orm/nodes/data/code/abstract.py @@ -301,7 +301,7 @@ def get_builder(self) -> 'ProcessBuilder': except exceptions.EntryPointError: raise exceptions.EntryPointError(f'The calculation entry point `{entry_point}` could not be loaded') - builder = process_class.get_builder() # type: ignore + builder = process_class.get_builder() # type: ignore[union-attr] builder.code = self return builder diff --git a/aiida/orm/nodes/data/code/installed.py b/aiida/orm/nodes/data/code/installed.py index da0ea33876..6b0c3397ce 100644 --- a/aiida/orm/nodes/data/code/installed.py +++ b/aiida/orm/nodes/data/code/installed.py @@ -53,7 +53,7 @@ def _validate(self): """ super(Code, self)._validate() # Change to ``super()._validate()`` once deprecated ``Code`` class is removed. 
# pylint: disable=bad-super-call - if not self.computer: + if not self.computer: # type: ignore[truthy-bool] raise exceptions.ValidationError('The `computer` is undefined.') try: diff --git a/aiida/orm/nodes/data/enum.py b/aiida/orm/nodes/data/enum.py index 0d08767be9..8c74c043f8 100644 --- a/aiida/orm/nodes/data/enum.py +++ b/aiida/orm/nodes/data/enum.py @@ -83,7 +83,7 @@ def get_enum(self) -> t.Type[EnumType]: except ValueError as exc: raise ImportError(f'Could not reconstruct enum class because `{identifier}` could not be loaded.') from exc - def get_member(self) -> EnumType: # type: ignore + def get_member(self) -> EnumType: # type: ignore[misc, type-var] """Return the enum member reconstructed from the serialized data stored in the database. For the enum member to be successfully reconstructed, the class of course has to still be importable and its diff --git a/aiida/orm/nodes/links.py b/aiida/orm/nodes/links.py index 6b0a83842e..46e157eadf 100644 --- a/aiida/orm/nodes/links.py +++ b/aiida/orm/nodes/links.py @@ -200,7 +200,7 @@ def get_incoming( if only_uuid: link_triple = LinkTriple( - link_triple.node.uuid, # type: ignore + link_triple.node.uuid, # type: ignore[arg-type] link_triple.link_type, link_triple.link_label, ) diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index b614b5723a..57b54aaace 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -39,7 +39,7 @@ from .repository import NodeRepository if TYPE_CHECKING: - from aiida.plugins.entry_point import EntryPoint # type: ignore + from aiida.plugins.entry_point import EntryPoint # type: ignore[attr-defined] from ..implementation import BackendNode, StorageBackend @@ -52,7 +52,7 @@ class NodeCollection(EntityCollection[NodeType], Generic[NodeType]): """The collection of nodes.""" @staticmethod - def _entity_base_cls() -> Type['Node']: # type: ignore + def _entity_base_cls() -> Type['Node']: # type: ignore[override] return Node def delete(self, pk: int) -> None: diff --git a/aiida/orm/nodes/process/calculation/calcfunction.py b/aiida/orm/nodes/process/calculation/calcfunction.py index 818fec3d06..773b895eb7 100644 --- a/aiida/orm/nodes/process/calculation/calcfunction.py +++ b/aiida/orm/nodes/process/calculation/calcfunction.py @@ -47,7 +47,7 @@ def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str ) -class CalcFunctionNode(FunctionCalculationMixin, CalculationNode): # type: ignore +class CalcFunctionNode(FunctionCalculationMixin, CalculationNode): # type: ignore[misc] """ORM class for all nodes representing the execution of a calcfunction.""" _CLS_NODE_LINKS = CalcFunctionNodeLinks diff --git a/aiida/orm/nodes/process/workflow/workchain.py b/aiida/orm/nodes/process/workflow/workchain.py index 0a673431c1..eba864a25c 100644 --- a/aiida/orm/nodes/process/workflow/workchain.py +++ b/aiida/orm/nodes/process/workflow/workchain.py @@ -23,7 +23,7 @@ class WorkChainNode(WorkflowNode): STEPPER_STATE_INFO_KEY = 'stepper_state_info' @classproperty - def _updatable_attributes(cls) -> Tuple[str, ...]: # type: ignore + def _updatable_attributes(cls) -> Tuple[str, ...]: # type: ignore[override] # pylint: disable=no-self-argument return super()._updatable_attributes + (cls.STEPPER_STATE_INFO_KEY,) diff --git a/aiida/orm/nodes/process/workflow/workfunction.py b/aiida/orm/nodes/process/workflow/workfunction.py index 73d37c0ab2..2a28b3274e 100644 --- a/aiida/orm/nodes/process/workflow/workfunction.py +++ b/aiida/orm/nodes/process/workflow/workfunction.py @@ -45,7 +45,7 @@ def 
validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str ) -class WorkFunctionNode(FunctionCalculationMixin, WorkflowNode): # type: ignore +class WorkFunctionNode(FunctionCalculationMixin, WorkflowNode): # type: ignore[misc] """ORM class for all nodes representing the execution of a workfunction.""" _CLS_NODE_LINKS = WorkFunctionNodeLinks diff --git a/aiida/orm/nodes/repository.py b/aiida/orm/nodes/repository.py index ccc814e20b..f64cf8395f 100644 --- a/aiida/orm/nodes/repository.py +++ b/aiida/orm/nodes/repository.py @@ -235,11 +235,11 @@ def put_object_from_filelike(self, handle: io.BufferedReader, path: str): """ self._check_mutability() - if isinstance(handle, io.StringIO): - handle = io.BytesIO(handle.read().encode('utf-8')) + if isinstance(handle, io.StringIO): # type: ignore[unreachable] + handle = io.BytesIO(handle.read().encode('utf-8')) # type: ignore[unreachable] - if isinstance(handle, tempfile._TemporaryFileWrapper): # pylint: disable=protected-access - if 'b' in handle.file.mode: + if isinstance(handle, tempfile._TemporaryFileWrapper): # type: ignore[unreachable] # pylint: disable=protected-access + if 'b' in handle.file.mode: # type: ignore[unreachable] handle = io.BytesIO(handle.read()) else: handle = io.BytesIO(handle.read().encode('utf-8')) diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 9e5fa36ff7..7379443a52 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -248,7 +248,7 @@ def __str__(self) -> str: def __deepcopy__(self, memo) -> 'QueryBuilder': """Create deep copy of the instance.""" - return type(self)(backend=self.backend, **self.as_dict()) # type: ignore + return type(self)(backend=self.backend, **self.as_dict()) # type: ignore[arg-type] def get_used_tags(self, vertices: bool = True, edges: bool = True) -> List[str]: """Returns a list of all the vertices that are being used. @@ -568,10 +568,10 @@ def append( tag=tag, # for the first item joining_keyword/joining_value can be None, # but after they always default to 'with_incoming' of the previous item - joining_keyword=joining_keyword, # type: ignore - joining_value=joining_value, # type: ignore + joining_keyword=joining_keyword, # type: ignore[typeddict-item] + joining_value=joining_value, # type: ignore[typeddict-item] # same for edge_tag for which a default is applied - edge_tag=edge_tag, # type: ignore + edge_tag=edge_tag, # type: ignore[typeddict-item] outerjoin=outerjoin, ) ) @@ -830,7 +830,7 @@ def add_projection(self, tag_spec: Union[str, EntityClsType], projection_spec: P _projections = [] LOGGER.debug('Adding projection of %s: %s', tag_spec, projection_spec) if not isinstance(projection_spec, (list, tuple)): - projection_spec = [projection_spec] # type: ignore + projection_spec = [projection_spec] # type: ignore[list-item] for projection in projection_spec: if isinstance(projection, dict): _thisprojection = projection @@ -864,7 +864,7 @@ def set_debug(self, debug: bool) -> 'QueryBuilder': '`QueryBuilder.set_debug` is deprecated. 
Configure the log level of the AiiDA logger instead.', version=3 ) if not isinstance(debug, bool): - return TypeError('I expect a boolean') + raise TypeError('I expect a boolean') self._debug = debug return self @@ -1194,12 +1194,12 @@ def _get_ormclass( func = _get_ormclass_from_cls input_info = cls elif entity_type is not None: - func = _get_ormclass_from_str # type: ignore - input_info = entity_type # type: ignore + func = _get_ormclass_from_str # type: ignore[assignment] + input_info = entity_type # type: ignore[assignment] else: raise ValueError('Neither cls nor entity_type specified') - if isinstance(input_info, str) or not isinstance(input_info, Sequence): + if isinstance(input_info, str) or not isinstance(input_info, Sequence): # type: ignore[redundant-expr] input_info = (input_info,) ormclass = EntityTypes.NODE diff --git a/aiida/orm/utils/links.py b/aiida/orm/utils/links.py index 7334f7d632..2106a1d57c 100644 --- a/aiida/orm/utils/links.py +++ b/aiida/orm/utils/links.py @@ -159,7 +159,7 @@ def validate_link( f'source and target nodes must be stored in the same backend, but got {source.backend} and {target.backend}' ) - if source.uuid is None or target.uuid is None: + if source.uuid is None or target.uuid is None: # type: ignore[redundant-expr] raise ValueError('source or target node does not have a UUID') if source.uuid == target.uuid: diff --git a/aiida/parsers/parser.py b/aiida/parsers/parser.py index c4d4976063..92e988a778 100644 --- a/aiida/parsers/parser.py +++ b/aiida/parsers/parser.py @@ -67,7 +67,7 @@ def exit_codes(self) -> ExitCodesNamespace: @property def retrieved(self) -> 'FolderData': return self.node.base.links.get_outgoing().get_node_by_label( - self.node.process_class.link_label_retrieved # type: ignore + self.node.process_class.link_label_retrieved # type: ignore[attr-defined, return-value] ) @property @@ -159,7 +159,7 @@ def parse_calcfunction(**kwargs): # `parse_from_node` method will get an empty dictionary as a result, despite the `Parser.parse` method # having registered outputs. 
process = Process.current() - process.out_many(outputs) # type: ignore + process.out_many(outputs) # type: ignore[union-attr] return exit_code return dict(outputs) @@ -167,7 +167,7 @@ def parse_calcfunction(**kwargs): inputs = {'metadata': {'store_provenance': store_provenance}} inputs.update(parser.get_outputs_for_parsing()) - return parse_calcfunction.run_get_node(**inputs) # type: ignore + return parse_calcfunction.run_get_node(**inputs) # type: ignore[attr-defined] @abstractmethod def parse(self, **kwargs) -> Optional[ExitCode]: diff --git a/aiida/parsers/plugins/templatereplacer/parser.py b/aiida/parsers/plugins/templatereplacer/parser.py index fa5d0acf0f..2260a06ba8 100644 --- a/aiida/parsers/plugins/templatereplacer/parser.py +++ b/aiida/parsers/plugins/templatereplacer/parser.py @@ -63,7 +63,7 @@ def parse(self, **kwargs): # We always strip the content of the file from whitespace to simplify testing for expected output output_dict['retrieved_temporary_files'].append((retrieved_file, parsed_value)) - label = self.node.process_class.spec().default_output_node # type: ignore + label = self.node.process_class.spec().default_output_node # type: ignore[attr-defined] self.out(label, Dict(dict=output_dict)) return diff --git a/aiida/repository/backend/abstract.py b/aiida/repository/backend/abstract.py index c19f8629b1..5161889f1a 100644 --- a/aiida/repository/backend/abstract.py +++ b/aiida/repository/backend/abstract.py @@ -75,7 +75,10 @@ def put_object_from_filelike(self, handle: BinaryIO) -> str: :return: the generated fully qualified identifier for the object within the repository. :raises TypeError: if the handle is not a byte stream. """ - if not isinstance(handle, io.BufferedIOBase) and not self.is_readable_byte_stream(handle): + if ( + not isinstance(handle, io.BufferedIOBase) and # type: ignore[redundant-expr,unreachable] + not self.is_readable_byte_stream(handle) + ): raise TypeError(f'handle does not seem to be a byte stream: {type(handle)}.') return self._put_object_from_filelike(handle) @@ -143,7 +146,7 @@ def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None: """ @contextlib.contextmanager - def open(self, key: str) -> Iterator[BinaryIO]: # type: ignore + def open(self, key: str) -> Iterator[BinaryIO]: # type: ignore[return] """Open a file handle to an object stored under the given key. .. note:: this should only be used to open a handle to read an existing file. 
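The blanket `# type: ignore` to `# type: ignore[code]` conversions running through this patch are enforced by the `ignore-without-code` error code enabled in `pyproject.toml` further down. A hypothetical two-line illustration of why the narrower form is safer:

```python
# Bare ignore: silences the assignment error *and* any other error
# that later appears on this line, hiding future regressions.
x: int = 'not an int'  # type: ignore

# Coded ignore (what this patch converts to): silences only the named
# error code, so unrelated new errors on the same line still surface.
y: int = 'not an int'  # type: ignore[assignment]
```

With `enable_error_code = ["ignore-without-code"]`, mypy itself reports the first, uncoded comment.
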
To write a new file use the method diff --git a/aiida/schedulers/datastructures.py b/aiida/schedulers/datastructures.py index 6291a90b77..9cd611aa43 100644 --- a/aiida/schedulers/datastructures.py +++ b/aiida/schedulers/datastructures.py @@ -378,7 +378,7 @@ class JobTemplate(DefaultFieldsAttributeDict): # pylint: disable=too-many-insta ) if TYPE_CHECKING: - shebang: str + shebang: str | None submit_as_hold: bool rerunnable: bool job_environment: dict[str, str] | None @@ -388,8 +388,8 @@ class JobTemplate(DefaultFieldsAttributeDict): # pylint: disable=too-many-insta email_on_started: bool email_on_terminated: bool job_name: str - sched_output_path: str - sched_error_path: str + sched_output_path: str | None + sched_error_path: str | None sched_join_files: bool queue_name: str account: str diff --git a/aiida/storage/psql_dos/alembic_cli.py b/aiida/storage/psql_dos/alembic_cli.py index 5f37926d2d..a288ce1aac 100755 --- a/aiida/storage/psql_dos/alembic_cli.py +++ b/aiida/storage/psql_dos/alembic_cli.py @@ -9,6 +9,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Simple wrapper around the alembic command line tool that first loads an AiiDA profile.""" +from __future__ import annotations + import alembic import click from sqlalchemy.util.compat import nullcontext @@ -16,6 +18,7 @@ from aiida.cmdline import is_verbose from aiida.cmdline.groups.verdi import VerdiCommandGroup from aiida.cmdline.params import options +from aiida.manage.configuration import Profile from aiida.storage.psql_dos.migrator import PsqlDosMigrator @@ -23,7 +26,7 @@ class AlembicRunner: """Wrapper around the alembic command line tool that first loads an AiiDA profile.""" def __init__(self) -> None: - self.profile = None + self.profile: Profile | None = None def execute_alembic_command(self, command_name, connect=True, **kwargs): """Execute an Alembic CLI command. @@ -36,7 +39,7 @@ def execute_alembic_command(self, command_name, connect=True, **kwargs): migrator = PsqlDosMigrator(self.profile) context = migrator._alembic_connect() if connect else nullcontext(migrator._alembic_config()) # pylint: disable=protected-access - with context as config: + with context as config: # type: ignore[attr-defined] command = getattr(alembic.command, command_name) config.stdout = click.get_text_stream('stdout') command(config, **kwargs) diff --git a/aiida/storage/psql_dos/backend.py b/aiida/storage/psql_dos/backend.py index e236d81021..50b6af6ed0 100644 --- a/aiida/storage/psql_dos/backend.py +++ b/aiida/storage/psql_dos/backend.py @@ -110,7 +110,7 @@ def __init__(self, profile: Profile) -> None: self._session_factory: Optional[scoped_session] = None self._initialise_session() # save the URL of the database, for use in the __str__ method - self._db_url = self.get_session().get_bind().url # type: ignore + self._db_url = self.get_session().get_bind().url # type: ignore[union-attr] self._authinfos = authinfos.SqlaAuthInfoCollection(self) self._comments = comments.SqlaCommentCollection(self) @@ -139,7 +139,7 @@ def _initialise_session(self): Although, in the future, we may want to move the multi-thread handling to higher in the AiiDA stack. 
""" from aiida.storage.psql_dos.utils import create_sqlalchemy_engine - engine = create_sqlalchemy_engine(self._profile.storage_config) # type: ignore + engine = create_sqlalchemy_engine(self._profile.storage_config) # type: ignore[arg-type] self._session_factory = scoped_session(sessionmaker(bind=engine, future=True, expire_on_commit=True)) def get_session(self) -> Session: @@ -155,7 +155,7 @@ def close(self) -> None: # pylint: disable=no-member engine = self._session_factory.bind if engine is not None: - engine.dispose() # type: ignore + engine.dispose() # type: ignore[union-attr] self._session_factory.expunge_all() self._session_factory.close() self._session_factory = None @@ -379,7 +379,7 @@ def maintain(self, full: bool = False, dry_run: bool = False, **kwargs) -> None: if full: maintenance_context = ProfileAccessManager(self._profile).lock else: - maintenance_context = nullcontext # type: ignore + maintenance_context = nullcontext # type: ignore[assignment] with maintenance_context(): unreferenced_objects = self.get_unreferenced_keyset() diff --git a/aiida/storage/psql_dos/migrations/utils/create_dbattribute.py b/aiida/storage/psql_dos/migrations/utils/create_dbattribute.py index dc29cd0a17..24fb526cc2 100644 --- a/aiida/storage/psql_dos/migrations/utils/create_dbattribute.py +++ b/aiida/storage/psql_dos/migrations/utils/create_dbattribute.py @@ -87,7 +87,7 @@ def create_rows(key: str, value, node_id: int) -> list[dict]: # pylint: disable columns['ival'] = len(value) for subk, subv in value.items(): - if not isinstance(key, str) or not key: + if not isinstance(key, str) or not key: # type: ignore[redundant-expr] raise ValidationError('The key must be a non-empty string.') if '.' in key: raise ValidationError( diff --git a/aiida/tools/archive/create.py b/aiida/tools/archive/create.py index 611be240f3..74b323add7 100644 --- a/aiida/tools/archive/create.py +++ b/aiida/tools/archive/create.py @@ -220,6 +220,9 @@ def create_archive( ) else: for entry in entities: + if entry.pk is None or entry.uuid is None: + continue + if isinstance(entry, orm.Group): starting_uuids[EntityTypes.GROUP].add(entry.uuid) entity_ids[EntityTypes.GROUP].add(entry.pk) @@ -365,7 +368,10 @@ def _collect_all_entities( progress.set_description_str(progress_str('Nodes')) entity_ids[EntityTypes.NODE].update( - querybuilder().append(orm.Node, project='id').all(batch_size=batch_size, flat=True) # type: ignore + querybuilder().append(orm.Node, + project='id').all( # type: ignore[arg-type] + batch_size=batch_size, flat=True + ) ) progress.update() @@ -379,45 +385,63 @@ def _collect_all_entities( progress.set_description_str(progress_str('Groups')) progress.update() entity_ids[EntityTypes.GROUP].update( - querybuilder().append(orm.Group, project='id').all(batch_size=batch_size, flat=True) # type: ignore + querybuilder().append( + orm.Group, + project='id' # type: ignore[arg-type] + ).all(batch_size=batch_size, flat=True) ) progress.set_description_str(progress_str('Nodes-Groups')) progress.update() qbuilder = querybuilder().append(orm.Group, project='id', tag='group').append(orm.Node, with_group='group', project='id').distinct() - group_nodes: List[Tuple[int, int]] = qbuilder.all(batch_size=batch_size) # type: ignore + group_nodes: List[Tuple[int, int]] = qbuilder.all(batch_size=batch_size) # type: ignore[assignment] progress.set_description_str(progress_str('Computers')) progress.update() entity_ids[EntityTypes.COMPUTER].update( - querybuilder().append(orm.Computer, project='id').all(batch_size=batch_size, flat=True) # 
type: ignore + querybuilder().append( + orm.Computer, + project='id' # type: ignore[arg-type] + ).all(batch_size=batch_size, flat=True) ) progress.set_description_str(progress_str('AuthInfos')) progress.update() if include_authinfos: entity_ids[EntityTypes.AUTHINFO].update( - querybuilder().append(orm.AuthInfo, project='id').all(batch_size=batch_size, flat=True) # type: ignore + querybuilder().append( + orm.AuthInfo, + project='id' # type: ignore[arg-type] + ).all(batch_size=batch_size, flat=True) ) progress.set_description_str(progress_str('Logs')) progress.update() if include_logs: entity_ids[EntityTypes.LOG].update( - querybuilder().append(orm.Log, project='id').all(batch_size=batch_size, flat=True) # type: ignore + querybuilder().append( + orm.Log, + project='id' # type: ignore[arg-type] + ).all(batch_size=batch_size, flat=True) ) progress.set_description_str(progress_str('Comments')) progress.update() if include_comments: entity_ids[EntityTypes.COMMENT].update( - querybuilder().append(orm.Comment, project='id').all(batch_size=batch_size, flat=True) # type: ignore + querybuilder().append( + orm.Comment, + project='id' # type: ignore[arg-type] + ).all(batch_size=batch_size, flat=True) ) progress.set_description_str(progress_str('Users')) progress.update() entity_ids[EntityTypes.USER].update( - querybuilder().append(orm.User, project='id').all(batch_size=batch_size, flat=True) # type: ignore + querybuilder().append( + orm.User, + project='id' # type: ignore[arg-type] + ).all(batch_size=batch_size, flat=True) ) return group_nodes, link_data @@ -446,7 +470,7 @@ def _collect_required_entities( ) qbuilder.append(orm.Node, with_group='group', project='id') qbuilder.distinct() - group_nodes = qbuilder.all(batch_size=batch_size) # type: ignore + group_nodes = qbuilder.all(batch_size=batch_size) # type: ignore[assignment] entity_ids[EntityTypes.NODE].update(nid for _, nid in group_nodes) # get full set of nodes & links, following traversal rules @@ -581,7 +605,7 @@ def _stream_repo_files( f'Backend repository key format incompatible: {repository.key_format!r} != {key_format!r}' ) with get_progress_reporter()(desc='Archiving files: ', total=len(keys)) as progress: - for key, stream in repository.iter_object_streams(keys): # type: ignore + for key, stream in repository.iter_object_streams(keys): # type: ignore[arg-type] # to-do should we use assume the key here is correct, or always re-compute and check? 
writer.put_object(stream, key=key) progress.update() @@ -624,13 +648,13 @@ def _check_node_licenses( def _check_allowed(name): try: - return allowed_licenses(name) # type: ignore + return allowed_licenses(name) # type: ignore[misc, operator] except Exception as exc: raise LicensingException('allowed_licenses function error') from exc check_allowed = _check_allowed elif isinstance(allowed_licenses, Sequence): - check_allowed = lambda l: l in allowed_licenses # type: ignore + check_allowed = lambda l: l in allowed_licenses # type: ignore[operator] else: raise TypeError('allowed_licenses not a list or function') @@ -641,13 +665,13 @@ def _check_allowed(name): def _check_forbidden(name): try: - return forbidden_licenses(name) # type: ignore + return forbidden_licenses(name) # type: ignore[misc, operator] except Exception as exc: raise LicensingException('forbidden_licenses function error') from exc check_forbidden = _check_forbidden elif isinstance(forbidden_licenses, Sequence): - check_forbidden = lambda l: l in forbidden_licenses # type: ignore + check_forbidden = lambda l: l in forbidden_licenses # type: ignore[operator] else: raise TypeError('forbidden_licenses not a list or function') diff --git a/aiida/tools/archive/implementations/sqlite_zip/writer.py b/aiida/tools/archive/implementations/sqlite_zip/writer.py index 9be155faeb..493c28f2c1 100644 --- a/aiida/tools/archive/implementations/sqlite_zip/writer.py +++ b/aiida/tools/archive/implementations/sqlite_zip/writer.py @@ -108,7 +108,7 @@ def __exit__(self, *args, **kwargs): if self._zip_path: self._zip_path.close() self._central_dir = {} - if self._work_dir is not None and self._init_work_dir is None: + if self._init_work_dir is None: shutil.rmtree(self._work_dir, ignore_errors=True) self._zip_path = self._work_dir = self._conn = None self._in_context = False diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py index 994288d476..61dc349e8c 100644 --- a/aiida/tools/archive/imports.py +++ b/aiida/tools/archive/imports.py @@ -256,7 +256,7 @@ def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, # get matching emails from the backend output_email_id: Dict[str, int] = {} if input_id_email: - output_email_id = dict( # type: ignore + output_email_id = dict( # type: ignore[assignment] orm.QueryBuilder( backend=backend_to ).append(orm.User, filters={ @@ -295,7 +295,7 @@ def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, # get matching uuids from the backend backend_uuid_id: Dict[str, int] = {} if input_id_uuid: - backend_uuid_id = dict( # type: ignore + backend_uuid_id = dict( # type: ignore[assignment] orm.QueryBuilder( backend=backend_to ).append(orm.Computer, filters={ @@ -452,7 +452,7 @@ def _import_nodes( # get matching uuids from the backend backend_uuid_id: Dict[str, int] = {} if input_id_uuid: - backend_uuid_id = dict( # type: ignore + backend_uuid_id = dict( # type: ignore[assignment] orm.QueryBuilder( backend=backend_to ).append(orm.Node, filters={ @@ -533,7 +533,7 @@ def _import_logs( # get matching uuids from the backend backend_uuid_id: Dict[str, int] = {} if input_id_uuid: - backend_uuid_id = dict( # type: ignore + backend_uuid_id = dict( # type: ignore[assignment] orm.QueryBuilder( backend=backend_to ).append(orm.Log, filters={ @@ -736,7 +736,7 @@ def _import_comments( # get matching uuids from the backend backend_uuid_id: Dict[str, int] = {} if input_id_uuid: - backend_uuid_id = dict( # type: ignore + backend_uuid_id = dict( # type: ignore[assignment] 
orm.QueryBuilder( backend=backend ).append(orm.Comment, filters={ @@ -991,7 +991,7 @@ def _import_groups( # get matching uuids from the backend backend_uuid_id: Dict[str, int] = {} if input_id_uuid: - backend_uuid_id = dict( # type: ignore + backend_uuid_id = dict( # type: ignore[assignment] orm.QueryBuilder( backend=backend_to ).append(orm.Group, filters={ @@ -1097,7 +1097,7 @@ def _make_import_group( IMPORT_LOGGER.report(f'Created new import Group: PK={group_id}, label={label}') group_node_ids = set() else: - group_id = group.pk + group_id = group.pk # type: ignore[assignment] IMPORT_LOGGER.report(f'Using existing import Group: PK={group_id}, label={group.label}') group_node_ids = { pk for pk, in orm.QueryBuilder(backend=backend_to).append(orm.Group, filters={ @@ -1158,7 +1158,7 @@ def _add_files_to_repo(backend_from: StorageBackend, backend_to: StorageBackend, repository_to = backend_to.get_repository() repository_from = backend_from.get_repository() with get_progress_reporter()(desc='Adding archive files to repository', total=len(new_keys)) as progress: - for key, handle in repository_from.iter_object_streams(new_keys): # type: ignore + for key, handle in repository_from.iter_object_streams(new_keys): # type: ignore[arg-type] backend_key = repository_to.put_object_from_filelike(handle) if backend_key != key: raise ImportValidationError( diff --git a/aiida/tools/visualization/graph.py b/aiida/tools/visualization/graph.py index cecd10ee5e..36c6da8e1b 100644 --- a/aiida/tools/visualization/graph.py +++ b/aiida/tools/visualization/graph.py @@ -451,6 +451,7 @@ def add_node( style = {} if style_override is None else dict(style_override) style.update(self._global_node_style) if node.pk not in self._nodes or overwrite: + assert node.pk is not None _add_graphviz_node( self._graph, node, @@ -507,7 +508,7 @@ def _convert_link_types( LinkType.CALL_WORK ] elif isinstance(link_types, (str, LinkType)): - link_types_list = [link_types] # type: ignore + link_types_list = [link_types] # type: ignore[assignment] else: link_types_list = link_types return tuple(getattr(LinkType, l.upper()) if isinstance(l, str) else l for l in link_types_list) @@ -534,6 +535,7 @@ def add_incoming( # incoming nodes are found traversing backwards node_pk = self._load_node(node).pk + assert node_pk is not None valid_link_types = self._convert_link_types(link_types) traversed_graph = traverse_graph( (node_pk,), @@ -592,6 +594,7 @@ def add_outgoing( # outgoing nodes are found traversing forwards node_pk = self._load_node(node).pk + assert node_pk is not None valid_link_types = self._convert_link_types(link_types) traversed_graph = traverse_graph( (node_pk,), @@ -654,6 +657,7 @@ def recurse_descendants( # Get graph traversal rules where the given link types and direction are all set to True, # and all others are set to False origin_pk = self._load_node(origin).pk + assert origin_pk is not None valid_link_types = self._convert_link_types(link_types) traversed_graph = traverse_graph( (origin_pk,), @@ -739,6 +743,7 @@ def recurse_ancestors( # Get graph traversal rules where the given link types and direction are all set to True, # and all others are set to False origin_pk = self._load_node(origin).pk + assert origin_pk is not None valid_link_types = self._convert_link_types(link_types) traversed_graph = traverse_graph( (origin_pk,), diff --git a/pyproject.toml b/pyproject.toml index aebcba57a3..e2f314e8e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -324,7 +324,8 @@ max-locals = 20 [tool.pytest.ini_options] minversion = 
"7.0" -addopts = "--benchmark-skip --durations=50 --cov-report xml --cov-append " +xfail_strict = true +addopts = "--benchmark-skip --durations=50 --strict-config --strict-markers -ra --cov-report xml --cov-append " testpaths = [ "tests", ] @@ -363,6 +364,11 @@ indent_dictionary_value = false allow_split_before_dict_value = false [tool.mypy] +enable_error_code = [ + "ignore-without-code", + "redundant-expr", + "truthy-bool" +] show_error_codes = true scripts_are_modules = true show_traceback = true @@ -377,6 +383,7 @@ disallow_incomplete_defs = false warn_return_any = false disallow_any_generics = false disallow_subclassing_any = false +warn_unreachable = true [[tool.mypy.overrides]] module = 'aiida' diff --git a/tests/cmdline/commands/test_calcjob.py b/tests/cmdline/commands/test_calcjob.py index da67773446..5eefaee5ee 100644 --- a/tests/cmdline/commands/test_calcjob.py +++ b/tests/cmdline/commands/test_calcjob.py @@ -108,7 +108,7 @@ def init_profile(self, aiida_profile_clean, aiida_localhost, tmp_path): # pylin # Get the imported ArithmeticAddCalculation node ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') calculations = orm.QueryBuilder().append(ArithmeticAddCalculation).all()[0] - self.arithmetic_job: orm.CalcJobNode = calculations[0] # type: ignore + self.arithmetic_job: orm.CalcJobNode = calculations[0] # type: ignore[annotation-unchecked] self.cli_runner = CliRunner() diff --git a/tests/cmdline/commands/test_data.py b/tests/cmdline/commands/test_data.py index e57b6d6439..b1c4e1e413 100644 --- a/tests/cmdline/commands/test_data.py +++ b/tests/cmdline/commands/test_data.py @@ -572,7 +572,7 @@ def mock_check_output(options): class TestVerdiDataStructure(DummyVerdiDataListable, DummyVerdiDataExportable): """Test verdi data core.structure.""" - from aiida.orm.nodes.data.structure import has_ase # type: ignore + from aiida.orm.nodes.data.structure import has_ase # type: ignore[misc] @pytest.fixture(autouse=True) def init_profile(self, aiida_profile_clean, aiida_localhost, run_cli_command): # pylint: disable=unused-argument diff --git a/tests/cmdline/commands/test_process.py b/tests/cmdline/commands/test_process.py index c0f8a79388..b0ec942f3a 100644 --- a/tests/cmdline/commands/test_process.py +++ b/tests/cmdline/commands/test_process.py @@ -29,7 +29,7 @@ def await_condition(condition: t.Callable, timeout: int = 1): """Wait for the ``condition`` to evaluate to ``True`` within the ``timeout`` or raise.""" start_time = time.time() - while not condition: # type: ignore + while not condition(): if time.time() - start_time > timeout: raise RuntimeError(f'waiting for {condition} to evaluate to `True` timed out after {timeout} seconds.') diff --git a/tests/engine/test_process_function.py b/tests/engine/test_process_function.py index 3810fbae59..61902a1af9 100644 --- a/tests/engine/test_process_function.py +++ b/tests/engine/test_process_function.py @@ -570,7 +570,7 @@ def function_with_default(data_a=default): def test_multiple_default_serialization(): """Test that Python base type defaults are automatically serialized to the AiiDA node counterpart.""" - @workfunction # type: ignore + @workfunction # type: ignore[misc] def function_with_multiple_defaults(integer: int = 10, string: str = 'default', boolean: bool = False): return {'integer': integer, 'string': string, 'boolean': boolean} diff --git a/tests/plugins/test_entry_point.py b/tests/plugins/test_entry_point.py index 46866bfa06..89c67cd3b5 100644 --- a/tests/plugins/test_entry_point.py +++ 
diff --git a/tests/cmdline/commands/test_calcjob.py b/tests/cmdline/commands/test_calcjob.py
index da67773446..5eefaee5ee 100644
--- a/tests/cmdline/commands/test_calcjob.py
+++ b/tests/cmdline/commands/test_calcjob.py
@@ -108,7 +108,7 @@ def init_profile(self, aiida_profile_clean, aiida_localhost, tmp_path):  # pylin
 
         # Get the imported ArithmeticAddCalculation node
         ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add')
         calculations = orm.QueryBuilder().append(ArithmeticAddCalculation).all()[0]
-        self.arithmetic_job: orm.CalcJobNode = calculations[0]  # type: ignore
+        self.arithmetic_job: orm.CalcJobNode = calculations[0]  # type: ignore[annotation-unchecked]
 
         self.cli_runner = CliRunner()
diff --git a/tests/cmdline/commands/test_data.py b/tests/cmdline/commands/test_data.py
index e57b6d6439..b1c4e1e413 100644
--- a/tests/cmdline/commands/test_data.py
+++ b/tests/cmdline/commands/test_data.py
@@ -572,7 +572,7 @@ def mock_check_output(options):
 
 class TestVerdiDataStructure(DummyVerdiDataListable, DummyVerdiDataExportable):
     """Test verdi data core.structure."""
-    from aiida.orm.nodes.data.structure import has_ase  # type: ignore
+    from aiida.orm.nodes.data.structure import has_ase  # type: ignore[misc]
 
     @pytest.fixture(autouse=True)
     def init_profile(self, aiida_profile_clean, aiida_localhost, run_cli_command):  # pylint: disable=unused-argument
diff --git a/tests/cmdline/commands/test_process.py b/tests/cmdline/commands/test_process.py
index c0f8a79388..b0ec942f3a 100644
--- a/tests/cmdline/commands/test_process.py
+++ b/tests/cmdline/commands/test_process.py
@@ -29,7 +29,7 @@ def await_condition(condition: t.Callable, timeout: int = 1):
     """Wait for the ``condition`` to evaluate to ``True`` within the ``timeout`` or raise."""
     start_time = time.time()
 
-    while not condition:  # type: ignore
+    while not condition():
         if time.time() - start_time > timeout:
             raise RuntimeError(f'waiting for {condition} to evaluate to `True` timed out after {timeout} seconds.')
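The `test_process.py` hunk above is a behavioural fix rather than a comment cleanup: `while not condition:` tested the truthiness of the function object itself, which is always true, so the loop body never ran and the helper returned immediately without ever waiting. Calling the callable, `while not condition():`, restores the intended polling. This is precisely the class of bug the newly enabled truthiness error codes are meant to surface at type-check time. A standalone sketch of the corrected helper (the `time.sleep` is an assumption added here to avoid a busy loop; the hunk itself shows no sleep):

    import time
    import typing as t

    def await_condition(condition: t.Callable[[], bool], timeout: int = 1):
        """Poll ``condition()`` until it is true, or raise after ``timeout`` seconds."""
        start_time = time.time()
        while not condition():  # note the call: bare `not condition` is always False
            if time.time() - start_time > timeout:
                raise RuntimeError(f'waiting for {condition} timed out after {timeout} seconds')
            time.sleep(0.01)  # assumed polling interval, not present in the original

    await_condition(lambda: time.time() > 0)  # returns at once: condition already true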
"""Test the CodDbImporter class.""" - from aiida.orm.nodes.data.cif import has_pycifrw # type: ignore + from aiida.orm.nodes.data.cif import has_pycifrw # type: ignore[misc] def test_query_construction_1(self): """Test query construction.""" diff --git a/tests/tools/archive/migration/test_v04_to_v05.py b/tests/tools/archive/migration/test_v04_to_v05.py index 18eb9ff7f9..40aa38666f 100644 --- a/tests/tools/archive/migration/test_v04_to_v05.py +++ b/tests/tools/archive/migration/test_v04_to_v05.py @@ -8,7 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test archive file migration from export version 0.4 to 0.5""" -from aiida.storage.sqlite_zip.migrations.legacy import migrate_v4_to_v5 # type: ignore +from aiida.storage.sqlite_zip.migrations.legacy import migrate_v4_to_v5 # type: ignore[attr-defined] def test_migrate_external(migrate_from_func): diff --git a/tests/tools/archive/migration/test_v05_to_v06.py b/tests/tools/archive/migration/test_v05_to_v06.py index 2699a7f614..ecdf2fe633 100644 --- a/tests/tools/archive/migration/test_v05_to_v06.py +++ b/tests/tools/archive/migration/test_v05_to_v06.py @@ -9,7 +9,7 @@ ########################################################################### """Test archive file migration from export version 0.5 to 0.6""" from aiida.storage.psql_dos.migrations.utils.calc_state import STATE_MAPPING -from aiida.storage.sqlite_zip.migrations.legacy import migrate_v5_to_v6 # type: ignore +from aiida.storage.sqlite_zip.migrations.legacy import migrate_v5_to_v6 # type: ignore[attr-defined] from aiida.storage.sqlite_zip.migrations.utils import verify_metadata_version from tests.utils.archives import get_archive_file, read_json_files