Skip to content

Commit

Permalink
Add __torch_function__ for methods (pytorch#37091)
Browse files Browse the repository at this point in the history
Summary:
According to pytorch/rfcs#3

From the goals in the RFC:

1. Support subclassing `torch.Tensor` in Python (done here)
2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them (done here)
3. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `torch.Tensor`
   subclasses (done in pytorch#30730)
4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods. (done here)
5. Propagating subclass instances correctly also with operators, using
   views/slices/indexing/etc. (done here)
6. Preserve subclass attributes when using methods or views/slices/indexing. (done here)
7. A way to insert code that operates on both functions and methods uniformly
   (so we can write a single function that overrides all operators). (done here)
8. The ability to give external libraries a way to also define
   functions/methods that follow the `__torch_function__` protocol. (will be addressed in a separate PR)

This PR makes the following changes:

1. Adds the `self` argument to the arg parser.
2. Dispatches on `self` as well if `self` is not `nullptr`.
3. Adds a `torch._C.DisableTorchFunction` context manager to disable `__torch_function__`.
4. Adds a `torch::torch_function_enabled()` and `torch._C._torch_function_enabled()` to check the state of `__torch_function__`.
5. Dispatches all `torch._C.TensorBase` and `torch.Tensor` methods via `__torch_function__`.

TODO:

- [x] Sequence Methods
- [x] Docs
- [x] Tests

Closes pytorch#28361

Benchmarks in pytorch#37091 (comment)

Pull Request resolved: pytorch#37091

Reviewed By: ngimel

Differential Revision: D22765678

Pulled By: ezyang

fbshipit-source-id: 53f8aa17ddb8b1108c0997f6a7aa13cb5be73de0
  • Loading branch information
hameerabbasi authored and facebook-github-bot committed Aug 6, 2020
1 parent 92b7347 commit 3d46e02
Show file tree
Hide file tree
Showing 25 changed files with 1,387 additions and 105 deletions.
2 changes: 1 addition & 1 deletion benchmarks/overrides_benchmark/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}

return args[0] + args[1]
return super().__torch_function__(func, types, args, kwargs)
58 changes: 58 additions & 0 deletions docs/source/notes/extending.rst
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,64 @@ Also see the ``MetadataTensor`` example below for another variation on this
pattern but instead always returns a ``MetadataTensor`` to propagate metadata
through operations in the :mod:`torch` API.

The ``__torch_function__`` protocol is designed for full coverage of the API,
partial coverage may lead to undesirable results, in particular, certain
functions raising a ``TypeError``. This is especially true for subclasses,
where all three of `torch.add`, `torch.Tensor.__add__` and `torch.Tensor.add`
must be covered, even if they return exactly the same result. Failing to do
this may also lead to infinite recursion. If one requires the implementation
of a function from ``torch.Tensor`` subclasses, they must use
``super().__torch_function__`` inside their implementation.


Subclassing ``torch.Tensor``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
As of version 1.7.0, methods and functions applied on ``torch.Tensor`` subclasses
will return subclass instances instead of ``torch.Tensor`` instances::

>>> class SubTensor(torch.Tensor):
... pass
>>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__
'SubTensor'
>>> type(torch.add(SubTensor([0]), torch.Tensor([1]))).__name__
'SubTensor'

If multiple subclasses exist, the lowest one in the hierarchy will be chosen by
default. If there is no unique way to determine such a case, then a
``TypeError`` is raised::

>>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__
'SubTensor2'
>>> type(torch.add(SubTensor2([0]), torch.Tensor([1]))).__name__
'SubTensor2'
>>> torch.add(SubTensor([0]), OtherSubTensor([1]))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor]

If one wishes to have a global override for all tensor methods, one can use
``__torch_function__``. Here is an example that logs all function/method
calls::

class LoggingTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
if kwargs is None:
kwargs = {}
return super().__torch_function__(func, types, args, kwargs)

However, if one instead wishes to override a method on the Tensor subclass,
there one can do so either by directly overriding the method (by defining
it for a subclass), or by using ``__torch_function__`` and matching with
``func``.

One should be careful within ``__torch_function__`` for subclasses to always
call ``super().__torch_function__(func, ...)`` instead of ``func`` directly,
as was the case before version 1.7.0. Failing to do this may cause ``func``
to recurse back into ``__torch_function__`` and therefore cause infinite
recursion.

Extending :mod:`torch` with a :class:`Tensor` wrapper type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ ignore_errors = True
[mypy-torch.backends.quantized]
ignore_errors = True

[mypy-torch.overrides]
ignore_errors = True

[mypy-caffe2.python.*]
ignore_errors = True

Expand Down
129 changes: 119 additions & 10 deletions test/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import pprint

from torch.testing._internal.common_utils import TestCase
from torch._overrides import (
from torch.overrides import (
handle_torch_function,
has_torch_function,
get_overridable_functions,
get_testing_overrides,
is_tensor_method_or_property
)

Tensor = torch.Tensor
Expand Down Expand Up @@ -210,6 +211,15 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
return NotImplemented
return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs)

class SubTensor2(torch.Tensor):
pass

class SubSubTensor2(SubTensor2):
pass

class SubTensor3(torch.Tensor):
pass

@implements_sub(torch.mean)
def sub_mean(mat):
return 0
Expand Down Expand Up @@ -275,6 +285,16 @@ def sub_diagonal_foo(a, b, c=None):

# The dispatch table for SubDiagonalTensor's __torch_function__ implementation.
HANDLED_FUNCTIONS_TENSOR_LIKE = {}
HANDLED_FUNCTIONS_WRAPPERS = {}

def triggered_wrapper(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
wrapped._triggered = True
return f(*args, **kwargs)

wrapped._triggered = False
return wrapped

def implements_tensor_like(torch_function):
"Register a torch function override for TensorLike"
Expand All @@ -301,8 +321,14 @@ def generate_tensor_like_torch_implementations():
)
assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs))
for func, override in testing_overrides.items():
# decorate the overrides with implements_tensor_like
implements_tensor_like(func)(override)
# decorate the overrides with implements_tensor_like if it's not a
# torch.Tensor method
wrapped = triggered_wrapper(override)
HANDLED_FUNCTIONS_WRAPPERS[func] = wrapped
if is_tensor_method_or_property(func):
implements_sub(func)(wrapped)
else:
implements_tensor_like(func)(wrapped)

generate_tensor_like_torch_implementations()

Expand Down Expand Up @@ -461,21 +487,73 @@ def test_user_implementation_raises(self):
with self.assertRaises(ValueError):
quux(t1)

def test_tensor_subclass_propagation(self):
"""this test exercises the functionality described in
docs/source/notes/extending.rst#subclassing-torchtensor"""
t1 = torch.tensor([5])
t2 = torch.tensor([6])

s1 = SubTensor2([5])
s2 = SubTensor2([6])

ss1 = SubSubTensor2([5])
ss2 = SubSubTensor2([6])

sn1 = SubTensor3([5])
sn2 = SubTensor3([6])

# Check that leaf subclass is kept regardless of order
self.assertTrue(isinstance(s1 + t2, SubTensor2))
self.assertTrue(isinstance(t1 + s2, SubTensor2))
self.assertTrue(isinstance(s1 + s2, SubTensor2))

# Check indexing subclass is kept
self.assertTrue(isinstance(s1[0], SubTensor2))

# Check case for subclass of subclass.
self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
self.assertTrue(isinstance(ss1 + s2, SubSubTensor2))
self.assertTrue(isinstance(s1 + ss2, SubSubTensor2))
self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2))
self.assertTrue(isinstance(ss1 + t2, SubSubTensor2))
self.assertTrue(isinstance(t1 + ss2, SubSubTensor2))
self.assertTrue(isinstance(ss1[0], SubSubTensor2))

# Make sure unrelated class trees are not merged.
with self.assertRaises(TypeError):
s1 + sn2
with self.assertRaises(TypeError):
sn1 + s2


def generate_tensor_like_override_tests(cls):
from torch.testing._internal.generated.annotated_fn_args import annotated_args

def test_generator(func, override):
# If func corresponds to a torch.Tensor method or property.
if is_tensor_method_or_property(func):
# Generate an instance by using SubTensor,
def instance_gen():
return SubTensor([5])
else:
# Otherwise, TensorLike.
def instance_gen():
return TensorLike()

func_args = []
if inspect.isbuiltin(func) and func in annotated_args:
if func in annotated_args:
for arg in annotated_args[func]:
# Guess valid input to aten function based on type of argument
t = arg['simple_type']
if t.endswith('?'):
t = t[:-1]
if t == 'Tensor':
func_args.append(TensorLike())
if arg['name'] == 'self' and is_tensor_method_or_property(func):
func = func.__get__(instance_gen())
continue
func_args.append(instance_gen())
elif t == 'TensorList':
func_args.append([TensorLike(), TensorLike()])
func_args.append([instance_gen(), instance_gen()])
elif t == 'IntArrayRef':
size = arg.get('size', 2)
if size == 1:
Expand Down Expand Up @@ -503,18 +581,49 @@ def test_generator(func, override):
nargs = len(args.args)
if args.defaults is not None:
nargs -= len(args.defaults)
func_args += [TensorLike() for _ in range(nargs)]
func_args = [instance_gen() for _ in range(nargs)]
if args.varargs is not None:
func_args += [TensorLike(), TensorLike()]
func_args += [instance_gen(), instance_gen()]

def test(self):
self.assertEqual(func(*func_args), -1)
ret = func(*func_args)
# ret is None for certain protocols, e.g., `__weakref__` and `__setitem__`
# This is currently the best check but doesn't work for, for example,
# Tensor.__add__ because it redirects to Tensor.add.
if ret is None:
self.assertTrue(HANDLED_FUNCTIONS_WRAPPERS[func]._triggered)
return

self.assertEqual(ret, -1)

return test

for func, override in get_testing_overrides().items():
test_method = test_generator(func, override)
module = func.__module__
if func.__name__ == "__get__":
# __get__ is part of the descriptor protocol.
# https://docs.python.org/3/howto/descriptor.html
# This is used for properties of the form
# torch.Tensor.<property>, with the method __get__
# In this case we get the property name in two ways:

# This case for properties defined in C.
module = getattr(
func.__self__,
"__qualname__",
None
)

# This one for properties defined in Python.
if module is None:
module = "Tensor." + func.__self__.fget.__name__

# Unfortunately I couldn't find a way to unify these two cases
# and there is no way for general descriptors.
elif is_tensor_method_or_property(func):
module = "Tensor"
else:
module = func.__module__
if module:
name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__)
else:
Expand Down
2 changes: 1 addition & 1 deletion test/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import inspect

try:
import mypy.api
import mypy.api # type: ignore
HAVE_MYPY = True
except ImportError:
HAVE_MYPY = False
Expand Down
10 changes: 9 additions & 1 deletion tools/autograd/gen_annotated_fn_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
"""

from .utils import write, CodeTemplate
from .gen_python_functions import get_py_nn_functions, get_py_torch_functions, op_name
from .gen_python_functions import (
get_py_nn_functions,
get_py_torch_functions,
get_py_variable_methods,
op_name,
)
import textwrap
from .gen_autograd import load_aten_declarations

Expand All @@ -28,6 +33,9 @@ def gen_annotated(aten_path, out, template_path):
for func in recurse_dict(get_py_nn_functions(declarations)):
annotated_args.append(process_func("torch._C._nn", func))

for func in recurse_dict(get_py_variable_methods(declarations)):
annotated_args.append(process_func("torch.Tensor", func))

annotated_args = textwrap.indent("\n".join(annotated_args), " ")
env = {"annotated_args": annotated_args}
PY_ANNOTATED_ARGS = CodeTemplate.from_file(template_path + '/templates/annotated_fn_args.py')
Expand Down
26 changes: 22 additions & 4 deletions tools/autograd/gen_python_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def is_noarg_binding(overloads):
}, /*traceable=*/${traceable});
ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(args, kwargs, parsed_args);
auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
${check_has_torch_function}
switch (_r.idx) {
${dispatch}
Expand All @@ -833,7 +833,7 @@ def is_noarg_binding(overloads):
}, /*traceable=*/${traceable});
ParsedArgs<${max_args}> parsed_args;
auto _r = parser.parse(args, kwargs, parsed_args);
auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
${check_has_torch_function}
${dispatch}
${method_footer}
Expand All @@ -847,6 +847,7 @@ def is_noarg_binding(overloads):
static PyObject * ${pycname}(PyObject* self_, PyObject* args)
{
${method_header}
${check_has_torch_function}
${dispatch}
${method_footer}
}
Expand All @@ -855,7 +856,13 @@ def is_noarg_binding(overloads):

TORCH_FUNCTION_CHECK = CodeTemplate("""\
if(_r.has_torch_function()) {
return handle_torch_function(_r, args, kwargs, ${namespace}, ${modulename});
return handle_torch_function(_r, ${self_}, args, kwargs, ${namespace}, ${modulename});
}
""")

TORCH_FUNCTION_CHECK_NOARGS = CodeTemplate("""\
if(check_has_torch_function(self_)) {
return handle_torch_function(self_, ${name});
}
""")

Expand All @@ -881,6 +888,10 @@ def method_impl(name, declarations, is_python_method, module):

method_footer = ['END_HANDLE_TH_ERRORS']

check_has_torch_function = TORCH_FUNCTION_CHECK_NOARGS.substitute(
name='"' + name + '"',
) if is_python_method else ''

# emit dispatch
if is_noarg_binding(declarations):
dispatch = emit_single_dispatch(declaration, is_python_method)
Expand All @@ -890,6 +901,7 @@ def method_impl(name, declarations, is_python_method, module):
method_header=method_header,
dispatch=dispatch,
method_footer=method_footer,
check_has_torch_function=check_has_torch_function,
)

method_footer = ['Py_RETURN_NONE;'] + method_footer
Expand All @@ -914,9 +926,14 @@ def method_impl(name, declarations, is_python_method, module):
check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
namespace=NATIVE_NAMESPACE_MAPPING[module],
modulename='"' + module + '"',
self_="self_" if is_python_method else "nullptr",
)
else:
check_has_torch_function = ''
check_has_torch_function = TORCH_FUNCTION_CHECK.substitute(
namespace="THPVariableClass",
modulename='"torch.Tensor"',
self_="self_" if is_python_method else "nullptr",
)

max_args = max([get_python_argc(decl) for decl in declarations])
traceable = 'true' if all(should_trace(d) for d in declarations) else 'false'
Expand All @@ -931,6 +948,7 @@ def method_impl(name, declarations, is_python_method, module):
check_has_torch_function=check_has_torch_function,
dispatch=dispatch,
method_footer=method_footer,
self_="self_" if is_python_method else "nullptr",
)


Expand Down
Loading

0 comments on commit 3d46e02

Please sign in to comment.