Commit bcab2d6

Add type annotations for cpp_extension, utils.data, signal_handling (pytorch#42647)

Summary: Pull Request resolved: pytorch#42647

Reviewed By: ezyang

Differential Revision: D22967041

Pulled By: malfet

fbshipit-source-id: 35e124da0be56934faef56834a93b2b400decf66
rgommers authored and facebook-github-bot committed Aug 6, 2020
1 parent 608f99e commit bcab2d6
Showing 5 changed files with 38 additions and 29 deletions.
12 changes: 3 additions & 9 deletions mypy.ini
@@ -234,18 +234,9 @@ ignore_errors = True
[mypy-torch.contrib._tensorboard_vis]
ignore_errors = True

[mypy-torch.utils.cpp_extension]
ignore_errors = True

[mypy-torch.utils.bottleneck.__main__]
ignore_errors = True

[mypy-torch.utils.data]
ignore_errors = True

[mypy-torch.utils.data._utils.signal_handling]
ignore_errors = True

[mypy-torch.utils.data._utils.collate]
ignore_errors = True

@@ -448,6 +439,9 @@ ignore_missing_imports = True
[mypy-setuptools.*]
ignore_missing_imports = True

[mypy-distutils.*]
ignore_missing_imports = True

[mypy-nvd3.*]
ignore_missing_imports = True
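
With the per-module `ignore_errors` overrides removed above, those modules are now type-checked under the repository's mypy.ini. A minimal sketch (an assumption, not part of the commit) of checking one of them through mypy's Python API; the path is illustrative and in practice the project's own mypy invocation applies:

```python
# Illustrative only: run mypy programmatically against one of the newly
# checked modules, using the repository's config file.
from mypy import api

stdout, stderr, exit_status = api.run(
    ['--config-file', 'mypy.ini', 'torch/utils/cpp_extension.py']
)
print(stdout or stderr)
```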

8 changes: 7 additions & 1 deletion torch/_C/__init__.pyi.in
@@ -294,7 +294,7 @@ class _TensorBase(object):
grad_fn: Any
${tensor_method_hints}

# Defined in torch/csrs/cuda/Module.cpp
# Defined in torch/csrc/cuda/Module.cpp
class _CudaDeviceProperties:
name: str
major: _int
@@ -329,3 +329,9 @@ class _CudaEventBase:
def elapsed_time(self, other: _CudaEventBase) -> _float: ...
def synchronize(self) -> None: ...
def ipc_handle(self) -> bytes: ...

# Defined in torch/csrc/DataLoader.cpp
def _set_worker_signal_handlers(*arg: Any) -> None: ... # THPModule_setWorkerSignalHandlers
def _set_worker_pids(key: _int, child_pids: Tuple[_int, ...]) -> None: ... # THPModule_setWorkerPIDs
def _remove_worker_pids(loader_id: _int) -> None: ... # THPModule_removeWorkerPIDs
def _error_if_any_worker_fails() -> None: ... # THPModule_errorIfAnyWorkerFails
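
These new stubs give mypy signatures for the DataLoader helpers exposed from torch._C. A minimal sketch (assumed, not part of the commit) of typed calling code; the helper names `_register_workers` and `_unregister_workers` are illustrative:

```python
# Sketch of code the new stubs let mypy check; it mirrors the calls made in
# torch/utils/data/dataloader.py and signal_handling.py.
from typing import Tuple
import torch

def _register_workers(loader_id: int, worker_pids: Tuple[int, ...]) -> None:
    torch._C._set_worker_signal_handlers()
    torch._C._set_worker_pids(loader_id, worker_pids)

def _unregister_workers(loader_id: int) -> None:
    torch._C._error_if_any_worker_fails()
    torch._C._remove_worker_pids(loader_id)
```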
33 changes: 20 additions & 13 deletions torch/utils/cpp_extension.py
@@ -129,7 +129,10 @@ def _join_rocm_home(*paths):
ROCM_HOME = _find_rocm_home()
MIOPEN_HOME = _join_rocm_home('miopen') if ROCM_HOME else None
IS_HIP_EXTENSION = True if ((ROCM_HOME is not None) and (torch.version.hip is not None)) else False
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip is not None else None
ROCM_VERSION = None
if torch.version.hip is not None:
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split('.')[:2])
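
Rewriting the one-line conditional expression as an explicit if-block presumably lets mypy see the `None` check on `torch.version.hip` before the string is used. A standalone sketch of the same narrowing pattern with an illustrative name:

```python
# parse_version is illustrative; it mimics how ROCM_VERSION is derived above.
from typing import Optional, Tuple

def parse_version(hip_version: Optional[str]) -> Optional[Tuple[int, ...]]:
    if hip_version is None:
        return None
    # Past this point mypy treats hip_version as str, so .split type-checks.
    major_minor = hip_version.split('.')[:2]
    return tuple(int(v) for v in major_minor)
```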

CUDA_HOME = _find_cuda_home()
CUDNN_HOME = os.environ.get('CUDNN_HOME') or os.environ.get('CUDNN_PATH')
# PyTorch releases have the version pattern major.minor.patch, whereas when
@@ -259,8 +262,8 @@ def check_compiler_abi_compatibility(compiler):
try:
if sys.platform.startswith('linux'):
minimum_required_version = MINIMUM_GCC_VERSION
version = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
version = version.decode().strip().split('.')
versionstr = subprocess.check_output([compiler, '-dumpfullversion', '-dumpversion'])
version = versionstr.decode().strip().split('.')
else:
minimum_required_version = MINIMUM_MSVC_VERSION
compiler_info = subprocess.check_output(compiler, stderr=subprocess.STDOUT)
@@ -316,7 +319,7 @@ def with_options(cls, **options):
Returns a subclass with alternative constructor that extends any original keyword
arguments to the original constructor with the given options.
'''
class cls_with_options(cls):
class cls_with_options(cls): # type: ignore
def __init__(self, *args, **kwargs):
kwargs.update(options)
super().__init__(*args, **kwargs)
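
The `# type: ignore` here covers a known mypy limitation: using a variable (the `cls` argument) as a base class works at runtime but is rejected by the type checker. A hedged sketch of the same pattern outside the library, with illustrative names:

```python
# with_defaults is illustrative; like with_options above, it returns a
# subclass whose constructor injects extra keyword arguments.
def with_defaults(base, **options):
    class _Subclass(base):  # type: ignore
        def __init__(self, *args, **kwargs):
            kwargs.update(options)
            super().__init__(*args, **kwargs)
    _Subclass.__name__ = base.__name__
    return _Subclass

# Usage sketch: DebugDict = with_defaults(dict)  # behaves like dict
```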
@@ -613,7 +616,7 @@ def win_wrap_ninja_compile(sources,
cuda_post_cflags = list(extra_postargs)
cuda_post_cflags = win_cuda_flags(cuda_post_cflags)

from distutils.spawn import _nt_quote_args
from distutils.spawn import _nt_quote_args # type: ignore
cflags = _nt_quote_args(cflags)
post_cflags = _nt_quote_args(post_cflags)
if with_cuda:
@@ -786,6 +789,7 @@ def CUDAExtension(name, sources, *args, **kwargs):
libraries.append('torch_cpu')
libraries.append('torch_python')
if IS_HIP_EXTENSION:
assert ROCM_VERSION is not None
libraries.append('amdhip64' if ROCM_VERSION >= (3, 5) else 'hip_hcc')
libraries.append('c10_hip')
libraries.append('torch_hip')
@@ -1352,6 +1356,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose):
if CUDNN_HOME is not None:
extra_ldflags.append('-L{}'.format(os.path.join(CUDNN_HOME, 'lib64')))
elif IS_HIP_EXTENSION:
assert ROCM_VERSION is not None
extra_ldflags.append('-L{}'.format(_join_rocm_home('lib')))
extra_ldflags.append('-lamdhip64' if ROCM_VERSION >= (3, 5) else '-lhip_hcc')
return extra_ldflags
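
Both asserts added above narrow `ROCM_VERSION` from an Optional tuple before the `>=` comparison, which mypy would otherwise flag. A small sketch of the idiom with an illustrative function name:

```python
# pick_hip_runtime_library is illustrative; it mirrors the comparison above.
from typing import Optional, Tuple

def pick_hip_runtime_library(rocm_version: Optional[Tuple[int, int]]) -> str:
    # On HIP builds a version is always available, so the assert is safe and
    # tells mypy the value is non-None on this path.
    assert rocm_version is not None
    return '-lamdhip64' if rocm_version >= (3, 5) else '-lhip_hcc'
```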
@@ -1397,20 +1402,20 @@ def _get_cuda_arch_flags(cflags=None):
# First check for an env var (same as used by the main setup.py)
# Can be one or more architectures, e.g. "6.1" or "3.5;5.2;6.0;6.1;7.0+PTX"
# See cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None)
_arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST', None)

# If not given, determine what's needed for the GPU that can be found
if not arch_list:
if not _arch_list:
capability = torch.cuda.get_device_capability()
arch_list = ['{}.{}'.format(capability[0], capability[1])]
else:
# Deal with lists that are ' ' separated (only deal with ';' after)
arch_list = arch_list.replace(' ', ';')
_arch_list = _arch_list.replace(' ', ';')
# Expand named arches
for named_arch, archval in named_arches.items():
arch_list = arch_list.replace(named_arch, archval)
_arch_list = _arch_list.replace(named_arch, archval)

arch_list = arch_list.split(';')
arch_list = _arch_list.split(';')

flags = []
for arch in arch_list:
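
The rename to `_arch_list` presumably avoids reusing one variable for both the raw `Optional[str]` read from the environment and the final `List[str]`, a re-assignment mypy reports as incompatible. A small standalone sketch of the same split, with illustrative names and a placeholder fallback:

```python
# _arch_env / arch_list are illustrative; they mirror the two roles the
# original single variable used to play.
import os
from typing import List, Optional

_arch_env: Optional[str] = os.environ.get('TORCH_CUDA_ARCH_LIST')
if not _arch_env:
    arch_list: List[str] = ['7.0']  # placeholder for the detected GPU
else:
    arch_list = _arch_env.replace(' ', ';').split(';')
```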
@@ -1528,8 +1533,10 @@ def _run_ninja_build(build_directory, verbose, error_prefix):
_, error, _ = sys.exc_info()
# error.output contains the stdout and stderr of the build attempt.
message = error_prefix
if hasattr(error, 'output') and error.output:
message += ": {}".format(error.output.decode())
# `error` is a CalledProcessError (which has an `output` attribute), but
# mypy thinks it's Optional[BaseException] and doesn't narrow
if hasattr(error, 'output') and error.output: # type: ignore
message += ": {}".format(error.output.decode()) # type: ignore
raise RuntimeError(message)
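
The two ignores above work around `sys.exc_info()` being typed as returning `Optional[BaseException]`, which mypy will not narrow to `CalledProcessError` via `hasattr`. An alternative sketch (an assumption, not what the commit does) that binds the exception directly and needs no ignores:

```python
# run_build is illustrative; it shows the except-binding form of the same
# error reporting done in _run_ninja_build.
import subprocess

def run_build(command, cwd, error_prefix: str) -> None:
    try:
        subprocess.run(command, cwd=cwd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        # e is precisely typed, so .output is known to mypy.
        message = error_prefix
        if e.output:
            message += ": {}".format(e.output.decode())
        raise RuntimeError(message) from e
```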


@@ -1580,7 +1587,7 @@ def _write_ninja_file_to_build_library(path,

if IS_WINDOWS:
cflags = common_cflags + COMMON_MSVC_FLAGS + extra_cflags
from distutils.spawn import _nt_quote_args
from distutils.spawn import _nt_quote_args # type: ignore
cflags = _nt_quote_args(cflags)
else:
cflags = common_cflags + ['-fPIC', '-std=c++14'] + extra_cflags
3 changes: 2 additions & 1 deletion torch/utils/data/_utils/signal_handling.py
@@ -49,7 +49,7 @@ def _set_SIGCHLD_handler():
if IS_WINDOWS:
return
# can't set signal in child threads
if not isinstance(threading.current_thread(), threading._MainThread):
if not isinstance(threading.current_thread(), threading._MainThread): # type: ignore
return
global _SIGCHLD_handler_set
if _SIGCHLD_handler_set:
@@ -65,6 +65,7 @@ def handler(signum, frame):
# Python can still get and update the process status successfully.
_error_if_any_worker_fails()
if previous_handler is not None:
assert callable(previous_handler)
previous_handler(signum, frame)

signal.signal(signal.SIGCHLD, handler)
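
The `assert callable(previous_handler)` narrows the stored handler, whose type is a union of a callable, an int (`SIG_DFL`/`SIG_IGN`), and `None`. A standalone sketch of the idiom (SIGTERM is used only so the example runs on any platform):

```python
# Illustrative: chain to a previously installed handler only when it is an
# actual callable; recent mypy versions also narrow the union on callable().
import signal

previous_handler = signal.getsignal(signal.SIGTERM)

def handler(signum, frame):
    print('cleaning up')
    if previous_handler is not None and callable(previous_handler):
        previous_handler(signum, frame)

signal.signal(signal.SIGTERM, handler)
```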
11 changes: 6 additions & 5 deletions torch/utils/data/dataloader.py
@@ -109,8 +109,8 @@ class DataLoader(Generic[T_co]):
worker_init_fn (callable, optional): If not ``None``, this will be called on each
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
input, after seeding and before data loading. (default: ``None``)
prefetch_factor (int, optional, keyword-only arg): Number of samples loaded
in advance by each worker. ``2`` means there will be a total of
prefetch_factor (int, optional, keyword-only arg): Number of samples loaded
in advance by each worker. ``2`` means there will be a total of
2 * num_workers samples prefetched across all workers. (default: ``2``)
@@ -152,9 +152,9 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0, collate_fn: _collate_fn_t = None,
pin_memory: bool = False, drop_last: bool = False,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: _worker_init_fn_t = None,
multiprocessing_context=None, generator=None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2):
torch._C._log_api_usage_once("python.data_loader") # type: ignore

@@ -797,7 +797,8 @@ def __init__(self, loader):
else:
self._data_queue = self._worker_result_queue

_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
# .pid can be None only before process is spawned (not the case, so ignore)
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore
_utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True
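
The ignore on the `_set_worker_pids` call covers `multiprocessing.Process.pid` being typed `Optional[int]`; it is only `None` before `start()`, which has already happened here. An ignore-free alternative sketch with an illustrative helper name:

```python
# collect_worker_pids is illustrative; it narrows each pid before building
# the tuple that _set_worker_pids expects.
from typing import Iterable, Tuple
import multiprocessing as mp

def collect_worker_pids(workers: Iterable[mp.Process]) -> Tuple[int, ...]:
    pids = []
    for w in workers:
        assert w.pid is not None  # workers have already been started
        pids.append(w.pid)
    return tuple(pids)
```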

