Skip to content

Commit

Permalink
Address mypy warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed Sep 1, 2023
1 parent 4306db1 commit 58e964b
Show file tree
Hide file tree
Showing 26 changed files with 85 additions and 60 deletions.
10 changes: 6 additions & 4 deletions aiida/cmdline/commands/cmd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""`verdi config` command."""
from __future__ import annotations

import json
from pathlib import Path
import textwrap
Expand Down Expand Up @@ -40,7 +42,7 @@ def verdi_config_list(ctx, prefix, description: bool):
from aiida.manage.configuration import Config, Profile

config: Config = ctx.obj.config
profile: Profile = ctx.obj.get('profile', None)
profile: Profile | None = ctx.obj.get('profile', None)

if not profile:
echo.echo_warning('no profiles configured: run `verdi setup` to create one')
Expand Down Expand Up @@ -78,7 +80,7 @@ def verdi_config_show(ctx, option):
from aiida.manage.configuration.options import NO_DEFAULT

config: Config = ctx.obj.config
profile: Profile = ctx.obj.profile
profile: Profile | None = ctx.obj.profile

dct = {
'schema': option.schema,
Expand Down Expand Up @@ -124,7 +126,7 @@ def verdi_config_set(ctx, option, value, globally, append, remove):
echo.echo_critical('Cannot flag both append and remove')

config: Config = ctx.obj.config
profile: Profile = ctx.obj.profile
profile: Profile | None = ctx.obj.profile

if option.global_only:
globally = True
Expand Down Expand Up @@ -164,7 +166,7 @@ def verdi_config_unset(ctx, option, globally):
from aiida.manage.configuration import Config, Profile

config: Config = ctx.obj.config
profile: Profile = ctx.obj.profile
profile: Profile | None = ctx.obj.profile

if option.global_only:
globally = True
Expand Down
2 changes: 1 addition & 1 deletion aiida/cmdline/groups/verdi.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class VerdiCommandGroup(click.Group):
def add_verbosity_option(cmd: click.Command):
"""Apply the ``verbosity`` option to the command, which is common to all ``verdi`` commands."""
# Only apply the option if it hasn't been already added in a previous call.
if cmd is not None and 'verbosity' not in [param.name for param in cmd.params]:
if 'verbosity' not in [param.name for param in cmd.params]:
cmd = options.VERBOSITY()(cmd)

return cmd
Expand Down
3 changes: 1 addition & 2 deletions aiida/engine/daemon/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,8 +713,7 @@ def _start_daemon(self, number_workers: int = 1, foreground: bool = False) -> No
pidfile.create(os.getpid())

# Configure the logger
loggerconfig = None
loggerconfig = loggerconfig or arbiter.loggerconfig or None
loggerconfig = arbiter.loggerconfig or None
configure_logger(circus_logger, loglevel, logoutput, loggerconfig)

# Main loop
Expand Down
3 changes: 2 additions & 1 deletion aiida/engine/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .processes.builder import ProcessBuilder
from .processes.functions import FunctionProcess
from .processes.process import Process
from .runners import ResultAndPk
from .utils import instantiate_process, is_process_scoped # pylint: disable=no-name-in-module

__all__ = ('run', 'run_get_pk', 'run_get_node', 'submit')
Expand Down Expand Up @@ -60,7 +61,7 @@ def run_get_node(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[
return runner.run_get_node(process, *args, **inputs)


def run_get_pk(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], int]:
def run_get_pk(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> ResultAndPk:
"""Run the process with the supplied inputs in a local runner that will block until the process is completed.
:param process: the process class, instance, builder or function to run
Expand Down
4 changes: 2 additions & 2 deletions aiida/engine/processes/calcjobs/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,7 @@ def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]:
return None

if exit_code is not None and not isinstance(exit_code, ExitCode):
args = (scheduler.__class__.__name__, type(exit_code))
args = (scheduler.__class__.__name__, type(exit_code)) # type: ignore[unreachable]
raise ValueError('`{}.parse_output` returned neither an `ExitCode` nor None, but: {}'.format(*args))

return exit_code
Expand Down Expand Up @@ -797,7 +797,7 @@ def parse_retrieved_output(self, retrieved_temporary_folder: Optional[str] = Non
break

if exit_code is not None and not isinstance(exit_code, ExitCode):
args = (parser_class.__name__, type(exit_code))
args = (parser_class.__name__, type(exit_code)) # type: ignore[unreachable]
raise ValueError('`{}.parse` returned neither an `ExitCode` nor None, but: {}'.format(*args))

return exit_code
Expand Down
18 changes: 11 additions & 7 deletions aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_stack_size(size: int = 2) -> int: # type: ignore[return]
for size in itertools.count(size, 8): # pylint: disable=redefined-argument-from-local
frame = frame.f_back.f_back.f_back.f_back.f_back.f_back.f_back.f_back # type: ignore[assignment,union-attr]
except AttributeError:
while frame:
while frame: # type: ignore[truthy-bool]
frame = frame.f_back # type: ignore[assignment]
size += 1
return size - 1
Expand Down Expand Up @@ -234,6 +234,7 @@ def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]:
"""
result, node = run_get_node(*args, **kwargs)
assert node.pk is not None
return result, node.pk

@functools.wraps(function)
Expand Down Expand Up @@ -323,10 +324,13 @@ def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['Func
"""
# pylint: disable=too-many-statements
if not issubclass(node_class, ProcessNode) or not issubclass(node_class, FunctionCalculationMixin):
if (
not issubclass(node_class, ProcessNode) or # type: ignore[redundant-expr]
not issubclass(node_class, FunctionCalculationMixin) # type: ignore[unreachable]
):
raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`')

signature = inspect.signature(func)
signature = inspect.signature(func) # type: ignore[unreachable]

args: list[str] = []
varargs: str | None = None
Expand Down Expand Up @@ -586,11 +590,11 @@ def run(self) -> 'ExitCode' | None:

result = self._func(*args, **kwargs)

if result is None or isinstance(result, ExitCode):
return result
if result is None or isinstance(result, ExitCode): # type: ignore[redundant-expr]
return result # type: ignore[unreachable]

if isinstance(result, Data):
self.out(self.SINGLE_OUTPUT_LINKNAME, result)
if isinstance(result, Data): # type: ignore[unreachable]
self.out(self.SINGLE_OUTPUT_LINKNAME, result) # type: ignore[unreachable]
elif isinstance(result, collections.abc.Mapping):
for name, value in result.items():
self.out(name, value)
Expand Down
13 changes: 4 additions & 9 deletions aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,6 @@ def metadata(self) -> AttributeDict:
"""
try:
assert self.inputs is not None
return self.inputs.metadata
except (AssertionError, AttributeError):
return AttributeDict()
Expand Down Expand Up @@ -297,7 +296,6 @@ def get_provenance_inputs_iterator(self) -> Iterator[Tuple[str, Union[InputPort,
:rtype: filter
"""
assert self.inputs is not None
return filter(lambda kv: not kv[0].startswith('_'), self.inputs.items())

@override
Expand Down Expand Up @@ -464,7 +462,7 @@ def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None:
self.report(''.join(traceback.format_exception(*exc_info)))

@override
def on_finish(self, result: Union[int, ExitCode], successful: bool) -> None:
def on_finish(self, result: Union[int, ExitCode, None], successful: bool) -> None:
""" Set the finish status on the process node.
:param result: result of the process
Expand Down Expand Up @@ -702,7 +700,6 @@ def _setup_db_record(self) -> None:
In addition, the parent calculation will be setup with a CALL link if applicable and all inputs will be
linked up as well.
"""
assert self.inputs is not None
assert not self.node.is_sealed, 'process node cannot be sealed when setting up the database record'

# Store important process attributes in the node proxy
Expand Down Expand Up @@ -731,9 +728,6 @@ def _setup_version_info(self) -> None:
"""Store relevant plugin version information."""
from aiida.plugins.entry_point import format_entry_point_string

if self.inputs is None:
return

version_info = self.runner.plugin_version_provider.get_version_info(self.__class__)

for key, monitor in self.inputs.get('monitors', {}).items():
Expand Down Expand Up @@ -836,7 +830,6 @@ def _flat_inputs(self) -> Dict[str, Any]:
:return: flat dictionary of parsed inputs
"""
assert self.inputs is not None
inputs = {key: value for key, value in self.inputs.items() if key != self.spec().metadata_key}
return dict(self._flatten_inputs(self.spec().inputs, inputs))

Expand Down Expand Up @@ -890,7 +883,9 @@ def _flatten_inputs(
items.extend(sub_items)
return items

assert (port is None) or (isinstance(port, InputPort) and (port.is_metadata or port.non_db))
assert (port is None) or (
isinstance(port, InputPort) and (port.is_metadata or port.non_db) # type: ignore[redundant-expr]
)
return []

def _flatten_outputs(
Expand Down
6 changes: 4 additions & 2 deletions aiida/engine/processes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ def prune_mapping(value):
:param value: A nested mapping of port values.
:return: The same mapping but without any nested namespace that is completely empty.
"""
if isinstance(value, Mapping) and not isinstance(value, Node):
if isinstance(value, Mapping) and not isinstance(value, Node): # type: ignore[unreachable]
result = {}
for key, sub_value in value.items():
pruned = prune_mapping(sub_value)
# If `pruned` is an "empty'ish" mapping and not an instance of `Node`, skip it, otherwise keep it.
if not (isinstance(pruned, Mapping) and not pruned and not isinstance(pruned, Node)):
if not (
isinstance(pruned, Mapping) and not pruned and not isinstance(pruned, Node) # type: ignore[unreachable]
):
result[key] = pruned
return result

Expand Down
3 changes: 2 additions & 1 deletion aiida/engine/processes/workchains/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ def _wrap_bare_dict_inputs(self, port_namespace: 'PortNamespace', inputs: Dict[s
continue

port = port_namespace[key]
valid_types = port.valid_type if isinstance(port.valid_type, (list, tuple)) else (port.valid_type,)
valid_types = port.valid_type \
if isinstance(port.valid_type, (list, tuple)) else (port.valid_type,) # type: ignore[redundant-expr]

if isinstance(port, PortNamespace):
wrapped[key] = self._wrap_bare_dict_inputs(port, value)
Expand Down
4 changes: 3 additions & 1 deletion aiida/engine/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
###########################################################################
# pylint: disable=global-statement
"""Runners that can run and submit processes."""
from __future__ import annotations

import asyncio
import functools
import logging
Expand Down Expand Up @@ -43,7 +45,7 @@ class ResultAndNode(NamedTuple):

class ResultAndPk(NamedTuple):
result: Dict[str, Any]
pk: int
pk: int | None


TYPE_RUN_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name
Expand Down
19 changes: 10 additions & 9 deletions aiida/orm/autogroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module to manage the autogrouping functionality by ``verdi run``."""
from __future__ import annotations

import re
from typing import List, Optional

from aiida.common import exceptions, timezone
from aiida.common.escaping import escape_for_sql_like, get_regex_pattern_from_sql
Expand Down Expand Up @@ -44,8 +45,8 @@ def __init__(self, backend):
self._backend = backend

self._enabled = False
self._exclude: Optional[List[str]] = None
self._include: Optional[List[str]] = None
self._exclude: list[str] | None = None
self._include: list[str] | None = None

self._group_label_prefix = f"Verdi autogroup on {timezone.now().strftime('%Y-%m-%d %H:%M:%S')}"
self._group_label = None # Actual group label, set by `get_or_create_group`
Expand All @@ -63,13 +64,13 @@ def disable(self) -> None:
"""Disable the auto-grouping."""
self._enabled = False

def get_exclude(self) -> Optional[List[str]]:
def get_exclude(self) -> list[str] | None:
"""Return the list of classes to exclude from autogrouping.
Returns ``None`` if no exclusion list has been set."""
return self._exclude

def get_include(self) -> Optional[List[str]]:
def get_include(self) -> list[str] | None:
"""Return the list of classes to include in the autogrouping.
Returns ``None`` if no inclusion list has been set."""
Expand All @@ -81,7 +82,7 @@ def get_group_label_prefix(self) -> str:
return self._group_label_prefix

@staticmethod
def validate(strings: Optional[List[str]]):
def validate(strings: list[str] | None):
"""Validate the list of strings passed to set_include and set_exclude."""
if strings is None:
return
Expand All @@ -97,7 +98,7 @@ def validate(strings: Optional[List[str]]):
f"'{string}' has an invalid prefix, must be among: {sorted(valid_prefixes)}"
)

def set_exclude(self, exclude: Optional[List[str]]) -> None:
def set_exclude(self, exclude: list[str] | str | None) -> None:
"""Set the list of classes to exclude in the autogrouping.
:param exclude: a list of valid entry point strings (might contain '%' to be used as
Expand All @@ -112,7 +113,7 @@ def set_exclude(self, exclude: Optional[List[str]]) -> None:
raise exceptions.ValidationError('Cannot both specify exclude and include')
self._exclude = exclude

def set_include(self, include: Optional[List[str]]) -> None:
def set_include(self, include: list[str] | str | None) -> None:
"""Set the list of classes to include in the autogrouping.
:param include: a list of valid entry point strings (might contain '%' to be used as
Expand All @@ -127,7 +128,7 @@ def set_include(self, include: Optional[List[str]]) -> None:
raise exceptions.ValidationError('Cannot both specify exclude and include')
self._include = include

def set_group_label_prefix(self, label_prefix: Optional[str]) -> None:
def set_group_label_prefix(self, label_prefix: str | None) -> None:
"""Set the label of the group to be created (or use a default)."""
if label_prefix is None:
label_prefix = f"Verdi autogroup on {timezone.now().strftime('%Y-%m-%d %H:%M:%S')}"
Expand Down
6 changes: 4 additions & 2 deletions aiida/orm/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module for all common top level AiiDA entity classes and methods"""
from __future__ import annotations

import abc
from enum import Enum
from functools import lru_cache
Expand Down Expand Up @@ -216,7 +218,7 @@ def initialize(self) -> None:
"""

@property
def id(self) -> int: # pylint: disable=invalid-name
def id(self) -> int | None: # pylint: disable=invalid-name
"""Return the id for this entity.
This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance.
Expand All @@ -229,7 +231,7 @@ def id(self) -> int: # pylint: disable=invalid-name
return self._backend_entity.id

@property
def pk(self) -> int:
def pk(self) -> int | None:
"""Return the primary key for this entity.
This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance.
Expand Down
4 changes: 3 additions & 1 deletion aiida/orm/implementation/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Classes and methods for backend non-specific entities"""
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterable, List, Tuple, Type, TypeVar

Expand Down Expand Up @@ -44,7 +46,7 @@ def id(self) -> int: # pylint: disable=invalid-name
"""

@property
def pk(self) -> int:
def pk(self) -> int | None:
"""Return the id for this entity.
This is unique only amongst entities of this type for a particular backend.
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/nodes/data/code/installed.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _validate(self):
"""
super(Code, self)._validate() # Change to ``super()._validate()`` once deprecated ``Code`` class is removed. # pylint: disable=bad-super-call

if not self.computer:
if not self.computer: # type: ignore[truthy-bool]
raise exceptions.ValidationError('The `computer` is undefined.')

try:
Expand Down
8 changes: 4 additions & 4 deletions aiida/orm/nodes/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,11 @@ def put_object_from_filelike(self, handle: io.BufferedReader, path: str):
"""
self._check_mutability()

if isinstance(handle, io.StringIO):
handle = io.BytesIO(handle.read().encode('utf-8'))
if isinstance(handle, io.StringIO): # type: ignore[unreachable]
handle = io.BytesIO(handle.read().encode('utf-8')) # type: ignore[unreachable]

if isinstance(handle, tempfile._TemporaryFileWrapper): # pylint: disable=protected-access
if 'b' in handle.file.mode:
if isinstance(handle, tempfile._TemporaryFileWrapper): # type: ignore[unreachable] # pylint: disable=protected-access
if 'b' in handle.file.mode: # type: ignore[unreachable]
handle = io.BytesIO(handle.read())
else:
handle = io.BytesIO(handle.read().encode('utf-8'))
Expand Down
Loading

0 comments on commit 58e964b

Please sign in to comment.