diff --git a/benchmarks/overrides_benchmark/common.py b/benchmarks/overrides_benchmark/common.py index 9651c0496a917..e1d6fb3656c9c 100644 --- a/benchmarks/overrides_benchmark/common.py +++ b/benchmarks/overrides_benchmark/common.py @@ -28,4 +28,4 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - return args[0] + args[1] + return super().__torch_function__(func, types, args, kwargs) diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index f0948c28ec19b..f167d705052f9 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -445,6 +445,64 @@ Also see the ``MetadataTensor`` example below for another variation on this pattern but instead always returns a ``MetadataTensor`` to propagate metadata through operations in the :mod:`torch` API. +The ``__torch_function__`` protocol is designed for full coverage of the API, +partial coverage may lead to undesirable results, in particular, certain +functions raising a ``TypeError``. This is especially true for subclasses, +where all three of `torch.add`, `torch.Tensor.__add__` and `torch.Tensor.add` +must be covered, even if they return exactly the same result. Failing to do +this may also lead to infinite recursion. If one requires the implementation +of a function from ``torch.Tensor`` subclasses, they must use +``super().__torch_function__`` inside their implementation. + + +Subclassing ``torch.Tensor`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +As of version 1.7.0, methods and functions applied on ``torch.Tensor`` subclasses +will return subclass instances instead of ``torch.Tensor`` instances:: + + >>> class SubTensor(torch.Tensor): + ... pass + >>> type(torch.add(SubTensor([0]), SubTensor([1]))).__name__ + 'SubTensor' + >>> type(torch.add(SubTensor([0]), torch.Tensor([1]))).__name__ + 'SubTensor' + +If multiple subclasses exist, the lowest one in the hierarchy will be chosen by +default. If there is no unique way to determine such a case, then a +``TypeError`` is raised:: + + >>> type(torch.add(SubTensor2([0]), SubTensor([1]))).__name__ + 'SubTensor2' + >>> type(torch.add(SubTensor2([0]), torch.Tensor([1]))).__name__ + 'SubTensor2' + >>> torch.add(SubTensor([0]), OtherSubTensor([1])) + Traceback (most recent call last): + File "", line 1, in + TypeError: no implementation found for 'torch.add' on types that implement __torch_function__: [SubTensor, OtherSubTensor] + +If one wishes to have a global override for all tensor methods, one can use +``__torch_function__``. Here is an example that logs all function/method +calls:: + + class LoggingTensor(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}") + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + +However, if one instead wishes to override a method on the Tensor subclass, +there one can do so either by directly overriding the method (by defining +it for a subclass), or by using ``__torch_function__`` and matching with +``func``. + +One should be careful within ``__torch_function__`` for subclasses to always +call ``super().__torch_function__(func, ...)`` instead of ``func`` directly, +as was the case before version 1.7.0. Failing to do this may cause ``func`` +to recurse back into ``__torch_function__`` and therefore cause infinite +recursion. + Extending :mod:`torch` with a :class:`Tensor` wrapper type ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/mypy.ini b/mypy.ini index 8519fb2e6f787..0ea9618485340 100644 --- a/mypy.ini +++ b/mypy.ini @@ -366,6 +366,9 @@ ignore_errors = True [mypy-torch.backends.quantized] ignore_errors = True +[mypy-torch.overrides] +ignore_errors = True + [mypy-caffe2.python.*] ignore_errors = True diff --git a/test/test_overrides.py b/test/test_overrides.py index cdc22aea097e6..7d058ea00c04f 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -6,11 +6,12 @@ import pprint from torch.testing._internal.common_utils import TestCase -from torch._overrides import ( +from torch.overrides import ( handle_torch_function, has_torch_function, get_overridable_functions, get_testing_overrides, + is_tensor_method_or_property ) Tensor = torch.Tensor @@ -210,6 +211,15 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return NotImplemented return HANDLED_FUNCTIONS_SUB[func](*args, **kwargs) +class SubTensor2(torch.Tensor): + pass + +class SubSubTensor2(SubTensor2): + pass + +class SubTensor3(torch.Tensor): + pass + @implements_sub(torch.mean) def sub_mean(mat): return 0 @@ -275,6 +285,16 @@ def sub_diagonal_foo(a, b, c=None): # The dispatch table for SubDiagonalTensor's __torch_function__ implementation. HANDLED_FUNCTIONS_TENSOR_LIKE = {} +HANDLED_FUNCTIONS_WRAPPERS = {} + +def triggered_wrapper(f): + @functools.wraps(f) + def wrapped(*args, **kwargs): + wrapped._triggered = True + return f(*args, **kwargs) + + wrapped._triggered = False + return wrapped def implements_tensor_like(torch_function): "Register a torch function override for TensorLike" @@ -301,8 +321,14 @@ def generate_tensor_like_torch_implementations(): ) assert len(untested_funcs) == 0, msg.format(pprint.pformat(untested_funcs)) for func, override in testing_overrides.items(): - # decorate the overrides with implements_tensor_like - implements_tensor_like(func)(override) + # decorate the overrides with implements_tensor_like if it's not a + # torch.Tensor method + wrapped = triggered_wrapper(override) + HANDLED_FUNCTIONS_WRAPPERS[func] = wrapped + if is_tensor_method_or_property(func): + implements_sub(func)(wrapped) + else: + implements_tensor_like(func)(wrapped) generate_tensor_like_torch_implementations() @@ -461,21 +487,73 @@ def test_user_implementation_raises(self): with self.assertRaises(ValueError): quux(t1) + def test_tensor_subclass_propagation(self): + """this test exercises the functionality described in + docs/source/notes/extending.rst#subclassing-torchtensor""" + t1 = torch.tensor([5]) + t2 = torch.tensor([6]) + + s1 = SubTensor2([5]) + s2 = SubTensor2([6]) + + ss1 = SubSubTensor2([5]) + ss2 = SubSubTensor2([6]) + + sn1 = SubTensor3([5]) + sn2 = SubTensor3([6]) + + # Check that leaf subclass is kept regardless of order + self.assertTrue(isinstance(s1 + t2, SubTensor2)) + self.assertTrue(isinstance(t1 + s2, SubTensor2)) + self.assertTrue(isinstance(s1 + s2, SubTensor2)) + + # Check indexing subclass is kept + self.assertTrue(isinstance(s1[0], SubTensor2)) + + # Check case for subclass of subclass. + self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2)) + self.assertTrue(isinstance(ss1 + s2, SubSubTensor2)) + self.assertTrue(isinstance(s1 + ss2, SubSubTensor2)) + self.assertTrue(isinstance(ss1 + ss2, SubSubTensor2)) + self.assertTrue(isinstance(ss1 + t2, SubSubTensor2)) + self.assertTrue(isinstance(t1 + ss2, SubSubTensor2)) + self.assertTrue(isinstance(ss1[0], SubSubTensor2)) + + # Make sure unrelated class trees are not merged. + with self.assertRaises(TypeError): + s1 + sn2 + with self.assertRaises(TypeError): + sn1 + s2 + + def generate_tensor_like_override_tests(cls): from torch.testing._internal.generated.annotated_fn_args import annotated_args def test_generator(func, override): + # If func corresponds to a torch.Tensor method or property. + if is_tensor_method_or_property(func): + # Generate an instance by using SubTensor, + def instance_gen(): + return SubTensor([5]) + else: + # Otherwise, TensorLike. + def instance_gen(): + return TensorLike() + func_args = [] - if inspect.isbuiltin(func) and func in annotated_args: + if func in annotated_args: for arg in annotated_args[func]: # Guess valid input to aten function based on type of argument t = arg['simple_type'] if t.endswith('?'): t = t[:-1] if t == 'Tensor': - func_args.append(TensorLike()) + if arg['name'] == 'self' and is_tensor_method_or_property(func): + func = func.__get__(instance_gen()) + continue + func_args.append(instance_gen()) elif t == 'TensorList': - func_args.append([TensorLike(), TensorLike()]) + func_args.append([instance_gen(), instance_gen()]) elif t == 'IntArrayRef': size = arg.get('size', 2) if size == 1: @@ -503,18 +581,49 @@ def test_generator(func, override): nargs = len(args.args) if args.defaults is not None: nargs -= len(args.defaults) - func_args += [TensorLike() for _ in range(nargs)] + func_args = [instance_gen() for _ in range(nargs)] if args.varargs is not None: - func_args += [TensorLike(), TensorLike()] + func_args += [instance_gen(), instance_gen()] def test(self): - self.assertEqual(func(*func_args), -1) + ret = func(*func_args) + # ret is None for certain protocols, e.g., `__weakref__` and `__setitem__` + # This is currently the best check but doesn't work for, for example, + # Tensor.__add__ because it redirects to Tensor.add. + if ret is None: + self.assertTrue(HANDLED_FUNCTIONS_WRAPPERS[func]._triggered) + return + + self.assertEqual(ret, -1) return test for func, override in get_testing_overrides().items(): test_method = test_generator(func, override) - module = func.__module__ + if func.__name__ == "__get__": + # __get__ is part of the descriptor protocol. + # https://docs.python.org/3/howto/descriptor.html + # This is used for properties of the form + # torch.Tensor., with the method __get__ + # In this case we get the property name in two ways: + + # This case for properties defined in C. + module = getattr( + func.__self__, + "__qualname__", + None + ) + + # This one for properties defined in Python. + if module is None: + module = "Tensor." + func.__self__.fget.__name__ + + # Unfortunately I couldn't find a way to unify these two cases + # and there is no way for general descriptors. + elif is_tensor_method_or_property(func): + module = "Tensor" + else: + module = func.__module__ if module: name = 'test_{}_{}'.format(module.replace('.', '_'), func.__name__) else: diff --git a/test/test_type_hints.py b/test/test_type_hints.py index 06c1aa005d484..3f6e1215a10b3 100644 --- a/test/test_type_hints.py +++ b/test/test_type_hints.py @@ -8,7 +8,7 @@ import inspect try: - import mypy.api + import mypy.api # type: ignore HAVE_MYPY = True except ImportError: HAVE_MYPY = False diff --git a/tools/autograd/gen_annotated_fn_args.py b/tools/autograd/gen_annotated_fn_args.py index 01692e546050d..7b4b0ece8da60 100644 --- a/tools/autograd/gen_annotated_fn_args.py +++ b/tools/autograd/gen_annotated_fn_args.py @@ -14,7 +14,12 @@ """ from .utils import write, CodeTemplate -from .gen_python_functions import get_py_nn_functions, get_py_torch_functions, op_name +from .gen_python_functions import ( + get_py_nn_functions, + get_py_torch_functions, + get_py_variable_methods, + op_name, +) import textwrap from .gen_autograd import load_aten_declarations @@ -28,6 +33,9 @@ def gen_annotated(aten_path, out, template_path): for func in recurse_dict(get_py_nn_functions(declarations)): annotated_args.append(process_func("torch._C._nn", func)) + for func in recurse_dict(get_py_variable_methods(declarations)): + annotated_args.append(process_func("torch.Tensor", func)) + annotated_args = textwrap.indent("\n".join(annotated_args), " ") env = {"annotated_args": annotated_args} PY_ANNOTATED_ARGS = CodeTemplate.from_file(template_path + '/templates/annotated_fn_args.py') diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index de2ac76799906..1e31dfbe1a6d7 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -812,7 +812,7 @@ def is_noarg_binding(overloads): }, /*traceable=*/${traceable}); ParsedArgs<${max_args}> parsed_args; - auto _r = parser.parse(args, kwargs, parsed_args); + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); ${check_has_torch_function} switch (_r.idx) { ${dispatch} @@ -833,7 +833,7 @@ def is_noarg_binding(overloads): }, /*traceable=*/${traceable}); ParsedArgs<${max_args}> parsed_args; - auto _r = parser.parse(args, kwargs, parsed_args); + auto _r = parser.parse(${self_}, args, kwargs, parsed_args); ${check_has_torch_function} ${dispatch} ${method_footer} @@ -847,6 +847,7 @@ def is_noarg_binding(overloads): static PyObject * ${pycname}(PyObject* self_, PyObject* args) { ${method_header} + ${check_has_torch_function} ${dispatch} ${method_footer} } @@ -855,7 +856,13 @@ def is_noarg_binding(overloads): TORCH_FUNCTION_CHECK = CodeTemplate("""\ if(_r.has_torch_function()) { - return handle_torch_function(_r, args, kwargs, ${namespace}, ${modulename}); + return handle_torch_function(_r, ${self_}, args, kwargs, ${namespace}, ${modulename}); +} +""") + +TORCH_FUNCTION_CHECK_NOARGS = CodeTemplate("""\ +if(check_has_torch_function(self_)) { + return handle_torch_function(self_, ${name}); } """) @@ -881,6 +888,10 @@ def method_impl(name, declarations, is_python_method, module): method_footer = ['END_HANDLE_TH_ERRORS'] + check_has_torch_function = TORCH_FUNCTION_CHECK_NOARGS.substitute( + name='"' + name + '"', + ) if is_python_method else '' + # emit dispatch if is_noarg_binding(declarations): dispatch = emit_single_dispatch(declaration, is_python_method) @@ -890,6 +901,7 @@ def method_impl(name, declarations, is_python_method, module): method_header=method_header, dispatch=dispatch, method_footer=method_footer, + check_has_torch_function=check_has_torch_function, ) method_footer = ['Py_RETURN_NONE;'] + method_footer @@ -914,9 +926,14 @@ def method_impl(name, declarations, is_python_method, module): check_has_torch_function = TORCH_FUNCTION_CHECK.substitute( namespace=NATIVE_NAMESPACE_MAPPING[module], modulename='"' + module + '"', + self_="self_" if is_python_method else "nullptr", ) else: - check_has_torch_function = '' + check_has_torch_function = TORCH_FUNCTION_CHECK.substitute( + namespace="THPVariableClass", + modulename='"torch.Tensor"', + self_="self_" if is_python_method else "nullptr", + ) max_args = max([get_python_argc(decl) for decl in declarations]) traceable = 'true' if all(should_trace(d) for d in declarations) else 'false' @@ -931,6 +948,7 @@ def method_impl(name, declarations, is_python_method, module): check_has_torch_function=check_has_torch_function, dispatch=dispatch, method_footer=method_footer, + self_="self_" if is_python_method else "nullptr", ) diff --git a/tools/autograd/templates/python_nn_functions.cpp b/tools/autograd/templates/python_nn_functions.cpp index 4c1cf2d19330c..e60de17790251 100644 --- a/tools/autograd/templates/python_nn_functions.cpp +++ b/tools/autograd/templates/python_nn_functions.cpp @@ -21,10 +21,22 @@ using namespace torch::autograd::utils; namespace torch { namespace autograd { +static PyObject* THPNNVariableFunctionsModule = NULL; + static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS - auto parsed = parse_to_conversion(args, kwargs, /*allow_copy*/ false); // we don't want copy for nn.Module.to + static PythonArgParser parser({ + "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + }); + ParsedArgs<5> parsed_args; + auto r = parser.parse(args, kwargs, parsed_args); + if (r.has_torch_function()) { + return handle_torch_function(r, args, kwargs, THPNNVariableFunctionsModule, "torch.nn"); + } + auto parsed = parse_to_conversion(r, /*allow_copy*/ false); // we don't want copy for nn.Module.to auto& device = std::get<0>(parsed); auto& scalarType = std::get<1>(parsed); auto non_blocking = std::get<2>(parsed); @@ -64,8 +76,6 @@ static PyMethodDef nn_functions[] = { {NULL} }; -static PyObject* THPNNVariableFunctionsModule = NULL; - void initNNFunctions(PyObject* module) { static struct PyModuleDef def = { PyModuleDef_HEAD_INIT, diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 39fc46cc1f1c7..731e5fd6099e5 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -47,6 +47,9 @@ namespace torch { namespace autograd { static PyObject * THPVariable__is_view(PyObject *self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "_is_view"); + } auto& self_ = reinterpret_cast(self)->cdata; if (self_.is_view()) { Py_RETURN_TRUE; @@ -61,6 +64,10 @@ static PyObject * THPVariable__is_view(PyObject *self, PyObject* args) static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + auto args = py::make_tuple(py::handle(arg)); + return handle_torch_function(self, "apply_", args.ptr()); + } auto& self_ = reinterpret_cast(self)->cdata; if (self_.requires_grad()) { throw std::runtime_error( @@ -81,7 +88,12 @@ static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwa }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<3> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + if (r.idx == 0) { if (jit::tracer::isTracing()) { return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0))); @@ -113,7 +125,12 @@ static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* k }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<3> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + if (r.idx == 0) { return wrap(self_.stride(r.toInt64(0))); } else if (r.idx == 1) { @@ -134,6 +151,9 @@ static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* k static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "get_device"); + } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.get_device()); END_HANDLE_TH_ERRORS @@ -142,6 +162,9 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args) static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "has_names"); + } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.has_names()); END_HANDLE_TH_ERRORS @@ -151,6 +174,9 @@ static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args) static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "data_ptr"); + } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.data_ptr()); END_HANDLE_TH_ERRORS @@ -160,6 +186,9 @@ static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args) static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "storage_offset"); + } auto& self = reinterpret_cast(self_)->cdata; return wrap(self.storage_offset()); END_HANDLE_TH_ERRORS @@ -169,6 +198,9 @@ static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args) static PyObject * THPVariable_dim(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "dim"); + } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.dim()); END_HANDLE_TH_ERRORS @@ -178,6 +210,9 @@ static PyObject * THPVariable_dim(PyObject* self, PyObject* args) static PyObject * THPVariable_numel(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "numel"); + } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.numel()); END_HANDLE_TH_ERRORS @@ -196,7 +231,12 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObjec "contiguous(*, MemoryFormat memory_format=contiguous_format)", }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto& self_ = reinterpret_cast(self)->cdata; auto memory_format = r.memoryformat(0); // avoids touching the GIL or current device if self is already contiguous @@ -236,7 +276,12 @@ static Tensor dispatch_copy_(Tensor & self, const Tensor & other, bool non_block }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<2> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + return THPVariable_Wrap(dispatch_copy_(self_, r.tensor(0), r.toBool(1))); END_HANDLE_TH_ERRORS } @@ -279,6 +324,14 @@ static bool dispatch_to_Bool(const Tensor & self) { static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + try { + return handle_torch_function(self, "__bool__"); + } + catch(const python_error&) { + return nullptr; + } + } jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; return wrap(dispatch_to_CDouble(self_)); @@ -287,6 +340,9 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__int__"); + } jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; if (isFloatingType(self_.scalar_type())) { @@ -303,6 +359,9 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { // called when used as a slice. static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__index__"); + } jit::tracer::warn("Converting a tensor to a Python index", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; // TODO: change the condition to `self_.dim() != 0` once we expose scalars @@ -322,6 +381,9 @@ static Tensor dispatch_invert(const Tensor & self) { static PyObject * THPVariable_invert(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "__float__"); + } auto& self_ = reinterpret_cast(self)->cdata; if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true)) { throw TypeError("~ (operator.invert) is only implemented on integer and Boolean-type tensors"); @@ -365,7 +427,12 @@ static PyObject * THPVariable_cpu(PyObject* self, PyObject* args, PyObject* kwar }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_Wrap(dispatch_to(self_, at::Device(at::DeviceType::CPU), false, false, opt_memory_format)); END_HANDLE_TH_ERRORS @@ -392,7 +459,12 @@ static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<2> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + if (r.idx == 0 || (r.idx == 1 && !r.toBool(0))) { return wrap(dispatch_nonzero(self_)); } else { @@ -410,7 +482,12 @@ static PyObject * THPVariable_cuda(PyObject* self, PyObject* args, PyObject* kwa }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<3> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto device = r.isNone(0) ? at::Device(at::DeviceType::CUDA) : r.device(0); auto opt_memory_format = r.memoryformatOptional(2); TORCH_CHECK(device.is_cuda(), "Invalid device, must be cuda device"); @@ -432,7 +509,12 @@ static PyObject * THPVariable_byte(PyObject* self, PyObject* args, PyObject* kwa "byte(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::Byte, opt_memory_format); END_HANDLE_TH_ERRORS @@ -444,7 +526,12 @@ static PyObject * THPVariable_char(PyObject* self, PyObject* args, PyObject* kwa "char(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::Char, opt_memory_format); END_HANDLE_TH_ERRORS @@ -456,7 +543,12 @@ static PyObject * THPVariable_double(PyObject* self, PyObject* args, PyObject* k "double(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::Double, opt_memory_format); END_HANDLE_TH_ERRORS @@ -468,7 +560,12 @@ static PyObject * THPVariable_float(PyObject* self, PyObject* args, PyObject* kw "float(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::Float, opt_memory_format); END_HANDLE_TH_ERRORS @@ -480,7 +577,12 @@ static PyObject * THPVariable_half(PyObject* self, PyObject* args, PyObject* kwa "half(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::Half, opt_memory_format); END_HANDLE_TH_ERRORS @@ -492,7 +594,12 @@ static PyObject * THPVariable_int(PyObject* self, PyObject* args, PyObject* kwar "int(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::Int, opt_memory_format); END_HANDLE_TH_ERRORS @@ -504,7 +611,12 @@ static PyObject * THPVariable_long(PyObject* self, PyObject* args, PyObject* kwa "long(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::Long, opt_memory_format); END_HANDLE_TH_ERRORS @@ -516,7 +628,12 @@ static PyObject * THPVariable_short(PyObject* self, PyObject* args, PyObject* kw "short(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::Short, opt_memory_format); END_HANDLE_TH_ERRORS @@ -528,7 +645,12 @@ static PyObject * THPVariable_bool(PyObject* self, PyObject* args, PyObject* kwa "bool(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::Bool, opt_memory_format); END_HANDLE_TH_ERRORS @@ -540,7 +662,12 @@ static PyObject * THPVariable_bfloat16(PyObject* self, PyObject* args, PyObject* "bfloat16(*, MemoryFormat? memory_format=None)" }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto opt_memory_format = r.memoryformatOptional(0); return THPVariable_to_type(self, ScalarType::BFloat16, opt_memory_format); END_HANDLE_TH_ERRORS @@ -549,6 +676,9 @@ static PyObject * THPVariable_bfloat16(PyObject* self, PyObject* args, PyObject* static PyObject * THPVariable_element_size(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "element_size"); + } auto& self_ = reinterpret_cast(self)->cdata; return THPUtils_packInt64(self_.element_size()); END_HANDLE_TH_ERRORS @@ -559,6 +689,9 @@ static PyObject * THPVariable_element_size(PyObject* self, PyObject* args) static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "numpy"); + } jit::tracer::warn("Converting a tensor to a NumPy array", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; return torch::utils::tensor_to_numpy(self_); @@ -569,6 +702,10 @@ static PyObject * THPVariable_numpy(PyObject* self, PyObject* arg) static PyObject * THPVariable_record_stream(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + auto args = py::make_tuple(py::handle(arg)); + return handle_torch_function(self, "record_stream", args.ptr()); + } #ifdef USE_CUDA auto& self_ = reinterpret_cast(self)->cdata; if (!THCPStream_Check(arg)) { @@ -590,7 +727,12 @@ static PyObject * THPVariable_requires_grad_(PyObject* self, PyObject* args, PyO }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto requires_grad = r.toBool(0); // should we throw if requires_grad is true? var.requires_grad = True throws here // but it's nice to let this be a no-op. @@ -617,7 +759,12 @@ static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyO "is_contiguous(*, MemoryFormat memory_format=contiguous_format)", }); ParsedArgs<1> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self_, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self_, args, kwargs, PyObject_Type(self_), "torch.Tensor"); + } + auto memory_format = r.memoryformat(0); auto& self = reinterpret_cast(self_)->cdata; return wrap(dispatch_is_contiguous(self, memory_format)); @@ -628,6 +775,9 @@ static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyO static PyObject * THPVariable_item(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "item"); + } jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW); auto& self_ = reinterpret_cast(self)->cdata; if (self_.is_floating_point()) { @@ -650,7 +800,12 @@ static PyObject * THPVariable_map_(PyObject* self, PyObject* args, PyObject* kwa static PythonArgParser parser({ "map_(Tensor other, PyObject* callable)" }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<2> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + Variable other = r.tensor(0); if (self_.requires_grad() || other.requires_grad()) { throw std::runtime_error( @@ -669,7 +824,12 @@ static PyObject * THPVariable_map2_(PyObject* self, PyObject* args, PyObject* kw static PythonArgParser parser({ "map2_(Tensor x, Tensor y, PyObject* callable)" }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<3> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + Variable x = r.tensor(0); Variable y = r.tensor(1); if (self_.requires_grad() || x.requires_grad() || y.requires_grad()) { @@ -684,6 +844,9 @@ static PyObject * THPVariable_map2_(PyObject* self, PyObject* args, PyObject* kw static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "new"); + } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); return THPVariable_Wrap(torch::utils::legacy_tensor_new(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); @@ -693,6 +856,9 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "new_ones"); + } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); return THPVariable_Wrap(torch::utils::new_ones(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); @@ -702,6 +868,9 @@ static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "new_tensor"); + } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); return THPVariable_Wrap(torch::utils::new_tensor(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); @@ -711,6 +880,9 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec static PyObject * THPVariable_storage(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "storage"); + } auto& self_ = reinterpret_cast(self)->cdata; return createPyObject(self_.storage(), self_.dtype()); END_HANDLE_TH_ERRORS @@ -719,6 +891,9 @@ static PyObject * THPVariable_storage(PyObject* self, PyObject* arg) static PyObject * THPVariable_storage_type(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "storage_type"); + } auto& self_ = reinterpret_cast(self)->cdata; auto storage = THPObjectPtr(createPyObject(self_.storage(), self_.dtype())); auto storage_type = (PyObject*)Py_TYPE(storage); @@ -730,7 +905,17 @@ static PyObject * THPVariable_storage_type(PyObject* self, PyObject* arg) static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS - auto parsed = parse_to_conversion(args, kwargs, /*allow_copy*/ true); + static PythonArgParser parser({ + "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", + }); + ParsedArgs<5> parsed_args; + auto r = parser.parse(self, args, kwargs, parsed_args); + if (r.has_torch_function()) { + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + auto parsed = parse_to_conversion(r, /*allow_copy*/ true); auto& device = std::get<0>(parsed); auto& scalarType = std::get<1>(parsed); auto non_blocking = std::get<2>(parsed); @@ -762,6 +947,9 @@ static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwarg static PyObject * THPVariable_tolist(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + return handle_torch_function(self, "tolist"); + } jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW); auto self_ = reinterpret_cast(self)->cdata; return torch::utils::tensor_to_list(self_); @@ -777,7 +965,12 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa }); auto& self_ = reinterpret_cast(self)->cdata; ParsedArgs<3> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse(self, args, kwargs, parsed_args); + + if(r.has_torch_function()){ + return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor"); + } + if (r.isNone(0)) { return THPUtils_packString(torch::utils::options_to_string(self_.options())); } @@ -822,6 +1015,14 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa ${py_methods} static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) { + if (check_has_torch_function(self)) { + try { + return handle_torch_function(self, "__bool__"); + } + catch(const python_error&) { + return nullptr; + } + } jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW); return THPVariable_is_nonzero(self, args); } @@ -830,6 +1031,7 @@ static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) { // Used to implement binary arithmetic operators template static PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, PyObject* kwargs) { + PyObject* ret = Func(self, args, kwargs); if (!ret && PyErr_ExceptionMatches(PyExc_TypeError)) { PyErr_Clear(); diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index 48a4dcf5b3456..4ab136cf9279e 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -525,6 +525,7 @@ libtorch_python_core_sources = [ "torch/csrc/utils/tensor_new.cpp", "torch/csrc/utils/tensor_numpy.cpp", "torch/csrc/utils/tensor_types.cpp", + "torch/csrc/utils/disable_torch_function.cpp", ] libtorch_python_distributed_sources = [ diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py index 8bb362d1c6ac1..b0cbf45b252b1 100644 --- a/torch/_lobpcg.py +++ b/torch/_lobpcg.py @@ -8,7 +8,7 @@ import torch from torch import Tensor from . import _linalg_utils as _utils -from ._overrides import has_torch_function, handle_torch_function +from .overrides import has_torch_function, handle_torch_function __all__ = ['lobpcg'] diff --git a/torch/_lowrank.py b/torch/_lowrank.py index e37c8e4c79bcd..8bb2dae8cfaab 100644 --- a/torch/_lowrank.py +++ b/torch/_lowrank.py @@ -8,7 +8,7 @@ import torch from torch import Tensor from . import _linalg_utils as _utils -from ._overrides import has_torch_function, handle_torch_function +from .overrides import has_torch_function, handle_torch_function def get_approximate_basis(A, # type: Tensor diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 0814029bafe62..493d38675ad44 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -578,6 +579,8 @@ static PyMethodDef TorchMethods[] = { {"_set_qengine", (PyCFunction)THPModule_setQEngine, METH_O, nullptr}, {"_supported_qengines", (PyCFunction)THPModule_supportedQEngines, METH_NOARGS, nullptr}, {"_is_xnnpack_enabled", (PyCFunction)THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr}, + {"_is_torch_function_enabled", (PyCFunction)THPModule_isEnabledTorchFunction, METH_NOARGS, nullptr}, + {"_disabled_torch_function_impl", (PyCFunction)THPModule_disable_torch_function, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr} }; @@ -802,11 +805,12 @@ Call this whenever a new thread is created in order to propagate values from THPDefaultCPUGenerator = (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator); // This reference is meant to be given away, so no need to incref here. ASSERT_TRUE(set_module_attr("default_generator", (PyObject*)THPDefaultCPUGenerator, /* incref= */ false)); - + ASSERT_TRUE(set_module_attr("DisableTorchFunction", (PyObject*)THPModule_DisableTorchFunctionType(), /* incref= */ false)); + torch::set_disabled_torch_function_impl(PyObject_GetAttrString(module, "_disabled_torch_function_impl")); + ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr); #ifdef USE_NUMPY if (_import_array() < 0) return nullptr; #endif - return module; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 46a3bbd07c2c2..e72717731ad67 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -151,7 +151,7 @@ static PyObject* THPVariable_as_subclass(THPVariable* self, PyObject* args, PyOb "as_subclass(PyObject* cls)", }); ParsedArgs<1> parsed_args{}; - auto r = parser.parse(args, kwargs, parsed_args); + auto r = parser.parse((PyObject *) self, args, kwargs, parsed_args); PyObject* cls = r.pyobject(0); if (!PyType_Check(cls)) { throw torch::TypeError("cls must be a type (got %s)", Py_TYPE(cls)->tp_name); @@ -193,6 +193,9 @@ typedef int (*setter)(PyObject *, PyObject *, void *); PyObject *THPVariable_get_T(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "T"); + } auto& var = self->cdata; return THPVariable_Wrap(var.numpy_T()); END_HANDLE_TH_ERRORS @@ -201,6 +204,9 @@ PyObject *THPVariable_get_T(THPVariable *self, void *unused) PyObject *THPVariable_get_cdata(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "_cdata"); + } auto& var = self->cdata; return PyLong_FromVoidPtr(var.unsafeGetTensorImpl()); END_HANDLE_TH_ERRORS @@ -209,6 +215,9 @@ PyObject *THPVariable_get_cdata(THPVariable *self, void *unused) PyObject *THPVariable_get_version(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "_version"); + } auto& var = self->cdata; return PyInt_FromLong(var._version()); END_HANDLE_TH_ERRORS @@ -217,6 +226,9 @@ PyObject *THPVariable_get_version(THPVariable *self, void *unused) PyObject *THPVariable_get_grad_fn(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "grad_fn"); + } auto& var = self->cdata; if (!var.grad_fn()) { Py_RETURN_NONE; @@ -228,6 +240,9 @@ PyObject *THPVariable_get_grad_fn(THPVariable *self, void *unused) static int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_setter(self, "_grad_fn", obj); + } THPUtils_assertRet(-1, obj, "Deletion of _grad_fn not allowed. Detach tensor instead!"); THPUtils_assertRet(-1, obj == Py_None, "_grad_fn can be only set to None"); self->cdata.detach_(); @@ -238,6 +253,9 @@ static int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj, void *unuse static PyObject *THPVariable_is_leaf(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "is_leaf"); + } return PyBool_FromLong(!self->cdata.grad_fn()); END_HANDLE_TH_ERRORS } @@ -245,6 +263,9 @@ static PyObject *THPVariable_is_leaf(THPVariable *self, void *unused) static PyObject * THPVariable_get_data(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "data"); + } auto var = self->cdata.variable_data(); return THPVariable_Wrap(var); END_HANDLE_TH_ERRORS @@ -253,6 +274,9 @@ static PyObject * THPVariable_get_data(THPVariable *self, void *unused) int THPVariable_set_data(THPVariable *self, PyObject *data, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_setter(self, "data", data); + } THPUtils_assertRet(-1, data, "Deleting tensor data is not allowed. Delete tensor instead!"); if (!THPVariable_Check(data)) { throw torch::TypeError("Variable data has to be a tensor, but got %s", Py_TYPE(data)->tp_name); @@ -266,6 +290,9 @@ int THPVariable_set_data(THPVariable *self, PyObject *data, void *unused) PyObject *THPVariable_get_grad(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "grad"); + } return THPVariable_Wrap(self->cdata.grad()); END_HANDLE_TH_ERRORS } @@ -273,6 +300,9 @@ PyObject *THPVariable_get_grad(THPVariable *self, void *unused) int THPVariable_set_grad(THPVariable *self, PyObject *py_grad, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_setter(self, "grad", py_grad); + } auto& var = self->cdata; if (!py_grad || py_grad == Py_None) { var.mutable_grad().reset(); @@ -304,6 +334,14 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad, void *unused) PyObject *THPVariable_get_volatile(THPVariable *self, void *unused) { + if (check_has_torch_function((PyObject *)self)) { + try { + return handle_torch_function_getter(self, "volatile"); + } + catch (const python_error&) { + return nullptr; + } + } const char* msg = "volatile was removed (Variable.volatile is always False)"; PyErr_WarnEx(PyExc_UserWarning, msg, 1); Py_RETURN_FALSE; @@ -311,12 +349,23 @@ PyObject *THPVariable_get_volatile(THPVariable *self, void *unused) int THPVariable_set_volatile(THPVariable *self, PyObject *obj, void *unused) { + if (check_has_torch_function((PyObject *)self)) { + try { + return handle_torch_function_setter(self, "volatile", obj); + } + catch (const python_error&) { + return -1; + } + } return PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1); } PyObject *THPVariable_get_output_nr(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "output_nr"); + } const auto output_nr = static_cast(self->cdata.output_nr()); return PyInt_FromLong(output_nr); END_HANDLE_TH_ERRORS @@ -325,6 +374,9 @@ PyObject *THPVariable_get_output_nr(THPVariable *self, void *unused) PyObject *THPVariable_get_requires_grad(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "requires_grad"); + } return PyBool_FromLong(self->cdata.requires_grad()); END_HANDLE_TH_ERRORS } @@ -332,6 +384,9 @@ PyObject *THPVariable_get_requires_grad(THPVariable *self, void *unused) PyObject *THPVariable_get_ndim(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "ndim"); + } return PyInt_FromLong(self->cdata.dim()); END_HANDLE_TH_ERRORS } @@ -339,6 +394,9 @@ PyObject *THPVariable_get_ndim(THPVariable *self, void *unused) PyObject *THPVariable_get_names(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "names"); + } // The long-term plan is to return a list of (python) torch.Dimname. // However, for now, return a list of string. size_t size = self->cdata.dim(); @@ -370,6 +428,9 @@ PyObject *THPVariable_get_names(THPVariable *self, void *unused) int THPVariable_set_names(THPVariable *self, PyObject *names) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_setter(self, "names", names); + } auto& var = self->cdata; if (names == Py_None) { at::internal_set_names_inplace(var, at::nullopt); @@ -386,6 +447,9 @@ int THPVariable_set_names(THPVariable *self, PyObject *names) { int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_setter(self, "requires_grad", obj); + } THPUtils_assertRet(-1, obj && PyBool_Check(obj), "requires_grad must be a bool"); auto& var = self->cdata; auto requires_grad = (obj == Py_True); @@ -404,6 +468,14 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj, void *unused PyObject *THPVariable_get_name(THPVariable* self, void *unused) { + if (check_has_torch_function((PyObject *)self)) { + try { + return handle_torch_function_getter(self, "name"); + } + catch (const python_error&) { + return nullptr; + } + } if (self->cdata.name() == "") Py_RETURN_NONE; return THPUtils_packString(self->cdata.name().c_str()); @@ -412,6 +484,9 @@ PyObject *THPVariable_get_name(THPVariable* self, void *unused) PyObject *THPVariable_get_backwards_hooks(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "_backward_hooks"); + } if (self->backward_hooks) { Py_INCREF(self->backward_hooks); return self->backward_hooks; @@ -423,6 +498,9 @@ PyObject *THPVariable_get_backwards_hooks(THPVariable *self, void *unused) int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_setter(self, "_backward_hooks", obj); + } THPUtils_assertRet(-1, obj, "Deletion of _backwards_hooks not allowed!"); if (obj == Py_None) { obj = nullptr; @@ -441,6 +519,9 @@ int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj, void *unus PyObject *THPVariable_get_base(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "_base"); + } if (self->cdata.is_view()) { return THPVariable_Wrap(self->cdata._base()); } @@ -451,6 +532,9 @@ PyObject *THPVariable_get_base(THPVariable *self, void *unused) PyObject *THPVariable_get_shape(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "shape"); + } return THPSize_New(self->cdata); END_HANDLE_TH_ERRORS } @@ -458,6 +542,9 @@ PyObject *THPVariable_get_shape(THPVariable *self, void *unused) PyObject *THPVariable_is_cuda(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "is_cuda"); + } auto& self_ = self->cdata; return torch::autograd::utils::wrap(self_.is_cuda()); END_HANDLE_TH_ERRORS @@ -466,6 +553,9 @@ PyObject *THPVariable_is_cuda(THPVariable *self, void *unused) PyObject *THPVariable_is_sparse(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "is_sparse"); + } auto& self_ = self->cdata; return torch::autograd::utils::wrap(self_.is_sparse()); END_HANDLE_TH_ERRORS @@ -474,6 +564,9 @@ PyObject *THPVariable_is_sparse(THPVariable *self, void *unused) PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "is_mkldnn"); + } auto& self_ = self->cdata; return torch::autograd::utils::wrap(self_.is_mkldnn()); END_HANDLE_TH_ERRORS @@ -482,6 +575,9 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused) PyObject *THPVariable_is_quantized(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "is_quantized"); + } auto& self_ = self->cdata; return torch::autograd::utils::wrap(self_.is_quantized()); END_HANDLE_TH_ERRORS @@ -490,6 +586,9 @@ PyObject *THPVariable_is_quantized(THPVariable *self, void *unused) PyObject *THPVariable_is_meta(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "is_meta"); + } auto& self_ = self->cdata; return torch::autograd::utils::wrap(self_.is_meta()); END_HANDLE_TH_ERRORS @@ -498,6 +597,9 @@ PyObject *THPVariable_is_meta(THPVariable *self, void *unused) PyObject *THPVariable_is_complex(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "is_complex"); + } auto& self_ = self->cdata; return torch::autograd::utils::wrap(self_.is_complex()); END_HANDLE_TH_ERRORS @@ -506,6 +608,9 @@ PyObject *THPVariable_is_complex(THPVariable *self, void *unused) static PyObject *THPVariable_dtype(THPVariable *self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "dtype"); + } auto& self_ = self->cdata; return torch::autograd::utils::wrap(torch::getTHPDtype(self_.scalar_type())); END_HANDLE_TH_ERRORS @@ -513,6 +618,9 @@ static PyObject *THPVariable_dtype(THPVariable *self, void *unused) static PyObject * THPVariable_layout(THPVariable* self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "layout"); + } auto& self_ = self->cdata; return torch::autograd::utils::wrap(torch::getTHPLayout(self_.layout())); END_HANDLE_TH_ERRORS @@ -520,6 +628,9 @@ static PyObject * THPVariable_layout(THPVariable* self, void *unused) { static PyObject * THPVariable_device(THPVariable* self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "device"); + } return THPDevice_New(self->cdata.device()); END_HANDLE_TH_ERRORS } @@ -527,6 +638,9 @@ static PyObject * THPVariable_device(THPVariable* self, void *unused) { PyObject *THPVariable_get_real(THPVariable* self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "real"); + } auto& self_ = self->cdata; auto real = at::real(self_); return THPVariable_Wrap(real); @@ -536,6 +650,9 @@ PyObject *THPVariable_get_real(THPVariable* self, void *unused) PyObject *THPVariable_get_imag(THPVariable* self, void *unused) { HANDLE_TH_ERRORS + if (check_has_torch_function((PyObject *)self)) { + return handle_torch_function_getter(self, "imag"); + } auto& self_ = self->cdata; auto imag = at::imag(self_); return THPVariable_Wrap(imag); diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 470e7339a4a95..4b38d924c91b8 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -31,6 +32,14 @@ namespace torch { namespace autograd { Py_ssize_t THPVariable_length(PyObject* self) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + py::object ret = py::reinterpret_steal(handle_torch_function(self, "__len__")); + Py_ssize_t length = PyLong_AsSsize_t(ret.ptr()); + if (PyErr_Occurred()) { + throw python_error(); + } + return length; + } auto& self_ = reinterpret_cast(self)->cdata; if (self_.dim() == 0) { return 0; @@ -258,6 +267,10 @@ static inline THPObjectPtr wrapTuple(PyObject* index) { // indexing is needed, it calls C++ `at::indexing::dispatch_index`. PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { HANDLE_TH_ERRORS + if (check_has_torch_function(self)) { + py::tuple args_ = py::make_tuple(py::handle(index)); + return handle_torch_function(self, "__getitem__", args_.ptr()); + } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); @@ -331,6 +344,11 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { if (py_value == nullptr) { throw TypeError("Tensor does not support deleting items"); } + if (check_has_torch_function(self)) { + py::tuple args_ = py::make_tuple(py::handle(index), py::handle(py_value)); + py::object ret = py::reinterpret_steal(handle_torch_function(self, "__setitem__", args_.ptr())); + return 0; + } auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); diff --git a/torch/csrc/autograd/utils/python_arg_parsing.h b/torch/csrc/autograd/utils/python_arg_parsing.h index ef91db9d1136b..603c3db073ae1 100644 --- a/torch/csrc/autograd/utils/python_arg_parsing.h +++ b/torch/csrc/autograd/utils/python_arg_parsing.h @@ -10,14 +10,7 @@ namespace torch { namespace autograd { namespace utils { // The parameter allow_copy is to accept copy for Tensor.to (and by proxy // PackedSequences.to) but not nn.Module.to. inline std::tuple, c10::optional, bool, bool, c10::optional> - parse_to_conversion(PyObject *args, PyObject *kwargs, bool allow_copy) { - static PythonArgParser parser({ - "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", - "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", - "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)", - }); - ParsedArgs<5> parsed_args; - auto r = parser.parse(args, kwargs, parsed_args); + parse_to_conversion(PythonArgs& r, bool allow_copy) { if (r.idx == 0) { if (!allow_copy && !r.isNone(3)) throw std::runtime_error(".to() does not accept copy argument"); diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp new file mode 100644 index 0000000000000..ff7478ac5f39d --- /dev/null +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -0,0 +1,127 @@ +#include +#include +#include + +namespace torch { + static thread_local bool enable_torch_function = true; + PyObject* disabled_torch_function = nullptr; + + bool torch_function_enabled() { + return enable_torch_function; + } + + PyObject* disabled_torch_function_impl() { + return disabled_torch_function; + } + + void set_disabled_torch_function_impl(PyObject* value) { + disabled_torch_function = value; + } +} + +typedef struct { + PyObject_HEAD + /* Type-specific fields go here. */ + bool old_state; +} DisableTorchFunction; + +PyObject* DisableTorchFunction__enter(PyObject* self, PyObject *unused) { + ((DisableTorchFunction*)self)->old_state = torch::enable_torch_function; + torch::enable_torch_function = false; + Py_RETURN_NONE; +} + +PyObject* DisableTorchFunction__exit(PyObject* self, PyObject *unused) { + torch::enable_torch_function = ((DisableTorchFunction*)self)->old_state; + Py_RETURN_NONE; +} + +PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused) { + if (torch::enable_torch_function) { + Py_RETURN_TRUE; + } else + { + Py_RETURN_FALSE; + } +} + +static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT + {"__enter__", (PyCFunction)DisableTorchFunction__enter, METH_NOARGS, nullptr}, + {"__exit__", (PyCFunction)DisableTorchFunction__exit, METH_VARARGS, nullptr}, + {nullptr, nullptr, 0, nullptr} +}; + +PyTypeObject DisableTorchFunctionType = { + PyVarObject_HEAD_INIT(nullptr, 0) + "torch._C.DisableTorchFunction", /* tp_name */ + sizeof(DisableTorchFunction), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + DisableTorchFunction_methods, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + PyType_GenericAlloc, /* tp_alloc */ + PyType_GenericNew, /* tp_new */ +}; + +PyObject* THPModule_DisableTorchFunctionType() { + if (PyType_Ready(&DisableTorchFunctionType) < 0) { + return nullptr; + } + + return (PyObject *)(&DisableTorchFunctionType); +} + +PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) { + HANDLE_TH_ERRORS + PyObject *func=nullptr, *types=nullptr, *args=nullptr, *kwargs=nullptr; + if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) { + return nullptr; + } + py::tuple py_args; + if (args == nullptr) { + py_args = py::make_tuple(); + } + else { + py_args = py::reinterpret_borrow(args); + } + + // These are all C-API calls so no exceptions will be raised + // and therefore no need for RAII approach to storing + // the old value. + bool old_value = torch::enable_torch_function; + torch::enable_torch_function = false; + // kwargs can safely be nullptr here. + PyObject *result = PyObject_Call(func, py_args.ptr(), kwargs); + torch::enable_torch_function = old_value; + return result; + END_HANDLE_TH_ERRORS +} diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h new file mode 100644 index 0000000000000..12166607f79c8 --- /dev/null +++ b/torch/csrc/utils/disable_torch_function.h @@ -0,0 +1,16 @@ +#pragma once +#include + +namespace torch { + // Sometimes we don't want infinite recursion for subclasses, + // Or a way to achieve the old behaviour. + + // This is an internal utility, not exposed to users. + bool torch_function_enabled(); + PyObject* disabled_torch_function_impl(); + void set_disabled_torch_function_impl(PyObject* value); +} + +PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject *unused); +PyObject* THPModule_DisableTorchFunctionType(); +PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *args); \ No newline at end of file diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 4aad3e45f82a2..97833bec008c3 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -135,11 +135,70 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only) } } -auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject* { +auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject* { + py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str()); + std::string module_name = "torch.Tensor." + property_name; + return handle_torch_function((PyObject *)self, "__get__", nullptr, torch_api.ptr(), module_name); +} + +auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int { + py::object torch_api = PyObject_FastGetAttrString(THPVariableClass, (char*)property_name.c_str()); + std::string module_name = "torch.Tensor." + property_name; + if (value != nullptr) + { + py::tuple args_ = py::make_tuple(py::handle(value)); + handle_torch_function((PyObject *)self, "__set__", args_.ptr(), torch_api.ptr(), module_name); + } + else { + handle_torch_function((PyObject *)self, "__delete__", nullptr, torch_api.ptr(), module_name); + } + return 0; +} + +// Combines self and args into one tuple. +auto combine_self_args(PyObject *self, PyObject *args) -> py::tuple { + if (args == nullptr) { + return py::make_tuple(py::handle(self)); + } + else if (self == nullptr) { + return py::reinterpret_borrow(args); + } + + auto py_args = py::reinterpret_borrow(args); + size_t n = py_args.size(); + auto args_ = py::tuple(n + 1); + args_[0] = py::handle(self); + for (size_t i = 0; i < n; i++) { + args_[i+1] = py_args[i]; + } + return args_; +} + +auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args, PyObject* torch_api, const std::string& module_name) -> PyObject* { + py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str()); + TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist"); + py::tuple args_ = combine_self_args(self, args); + py::tuple py_types = py::make_tuple(py::handle(PyObject_Type(self))); + py::object torch_function = PyObject_FastGetAttrString(self, "__torch_function__"); + py::object ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), NULL)); + if (ret.ptr() == nullptr) { + // if an exception occurred in a user's implementation of + // __torch_function__, throw it + throw python_error(); + } + if (ret.ptr() == Py_NotImplemented) { + std::string error_msg = "no implementation found for " + module_name + "." + func_name + "' on types that implement __torch_function__: [" + self->ob_type->tp_name + "]"; + PyErr_SetString(PyExc_TypeError, error_msg.c_str()); + throw python_error(); + } + return ret.release().ptr(); +} + +auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject* { py::object torch_api_function = PyObject_FastGetAttrString(torch_api, (char*)r.get_func_name().c_str()); TORCH_INTERNAL_ASSERT(torch_api_function.ptr() != nullptr, "torch API function must exist"); py::object ret; - + py::tuple args_ = combine_self_args(self, args); // overloaded_args already all have unique types std::vector overloaded_types; overloaded_types.reserve(r.signature.overloaded_args.size()); @@ -150,7 +209,7 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb for (auto &arg : r.signature.overloaded_args) { py::object torch_function = PyObject_FastGetAttrString(arg.ptr(), "__torch_function__"); - ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args, kwargs, NULL)); + ret = py::reinterpret_steal(PyObject_CallFunctionObjArgs(torch_function.ptr(), torch_api_function.ptr(), py_types.ptr(), args_.ptr(), kwargs, NULL)); if (ret.ptr() != Py_NotImplemented) { // Return the reference to the result. This also covers the case where ret // is NULL and __torch_function__ raised an exception, which we throw below @@ -159,7 +218,7 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb } if (ret.ptr() == nullptr) { // if an exception occurred in a user's implementation of - // __array_function__, throw it + // __torch_function__, throw it throw python_error(); } else if (ret.ptr() == Py_NotImplemented) { @@ -184,6 +243,11 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb return ret.release().ptr(); } +auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject* +{ + return handle_torch_function(r, nullptr, args, kwargs, torch_api, module_name); +} + /* * obj has a __torch_function__ implementation and may either be a * subclass of Tensor or a Tensor-like duck type. We may need to @@ -653,9 +717,9 @@ static void extra_kwargs(FunctionSignature& signature, PyObject* kwargs, ssize_t throw TypeError("invalid keyword arguments"); } -bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[], +bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* dst[], // NOLINT bool raise_exception) { - auto nargs = PyTuple_GET_SIZE(args); + auto nargs = args ? PyTuple_GET_SIZE(args) : 0; ssize_t remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0; ssize_t arg_pos = 0; bool allow_varargs_intlist = false; @@ -679,6 +743,9 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[], } int i = 0; + if (self != nullptr && !THPVariable_CheckExact(self) && check_has_torch_function(self)) { + append_overloaded_arg(this->overloaded_args, self); + } for (auto& param : params) { PyObject* obj = nullptr; bool is_kwd = false; @@ -798,25 +865,25 @@ void PythonArgParser::check_deprecated(const FunctionSignature & signature) { } } -PythonArgs PythonArgParser::raw_parse(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { +PythonArgs PythonArgParser::raw_parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { // NOLINT if (signatures_.size() == 1) { auto& signature = signatures_[0]; - signature.parse(args, kwargs, parsed_args, true); + signature.parse(self, args, kwargs, parsed_args, true); check_deprecated(signature); return PythonArgs(traceable, signature, parsed_args); } for (auto& signature : signatures_) { - if (signature.parse(args, kwargs, parsed_args, false)) { + if (signature.parse(self, args, kwargs, parsed_args, false)) { check_deprecated(signature); return PythonArgs(traceable, signature, parsed_args); } } - print_error(args, kwargs, parsed_args); + print_error(self, args, kwargs, parsed_args); } -void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { +void PythonArgParser::print_error(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { // NOLINT auto num_args = PyTuple_GET_SIZE(args) + (kwargs ? PyDict_Size(kwargs) : 0); std::vector plausible_idxs; ssize_t i = 0; @@ -829,7 +896,7 @@ void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* pa if (plausible_idxs.size() == 1) { auto& signature = signatures_[plausible_idxs[0]]; - signature.parse(args, kwargs, parsed_args, true); + signature.parse(self, args, kwargs, parsed_args, true); } auto options = get_signatures(); diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 0aef4dc0c6e62..e3ab5568f204f 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -60,6 +60,7 @@ #include #include #include +#include #include #include @@ -96,17 +97,22 @@ struct PythonArgParser { explicit PythonArgParser(std::vector fmts, bool traceable=false); // meant only for `torch` functions. + template + inline PythonArgs parse(PyObject* self, PyObject* args, PyObject* kwargs, ParsedArgs& dst); + template inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs& dst); + inline PythonArgs parse(PyObject* self, ParsedArgs<0>& dst); + // Formatted strings of non-hidden signatures std::vector get_signatures() const; private: [[noreturn]] - void print_error(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]); + void print_error(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]); void check_deprecated(const FunctionSignature & signature); - PythonArgs raw_parse(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]); + PythonArgs raw_parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* parsed_args[]); std::vector signatures_; std::string function_name; @@ -117,7 +123,7 @@ struct PythonArgParser { struct PYBIND11_EXPORT FunctionSignature { explicit FunctionSignature(const std::string& fmt, int index); - bool parse(PyObject* args, PyObject* kwargs, PyObject* dst[], bool raise_exception); + bool parse(PyObject* self, PyObject* args, PyObject* kwargs, PyObject* dst[], bool raise_exception); std::string toString() const; @@ -130,6 +136,7 @@ struct PYBIND11_EXPORT FunctionSignature { int index; bool hidden; bool deprecated; + bool disable_torch_function; }; struct PythonArgs { @@ -227,12 +234,21 @@ struct FunctionParameter { }; template -inline PythonArgs PythonArgParser::parse(PyObject* args, PyObject* kwargs, ParsedArgs& dst) { +inline PythonArgs PythonArgParser::parse(PyObject* self, PyObject* args, PyObject* kwargs, ParsedArgs& dst) { if (N < max_args) { throw ValueError("PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", (int)max_args, N); } - return raw_parse(args, kwargs, dst.args); + return raw_parse(self, args, kwargs, dst.args); +} + +template +inline PythonArgs PythonArgParser::parse(PyObject* args, PyObject* kwargs, ParsedArgs& dst) { + return parse(nullptr, args, kwargs, dst); +} + +inline PythonArgs PythonArgParser::parse(PyObject* self, ParsedArgs<0>& dst) { + return parse(self, nullptr, nullptr, dst); } inline bool PythonArgs::has_torch_function(){ @@ -683,10 +699,10 @@ static bool _is_basic_python_type(PyTypeObject *tp) static py::object PyTorch_LookupSpecial(PyObject *obj, char* name) { - PyTypeObject *tp = Py_TYPE(obj); if (THPVariable_CheckExact(obj)) { return py::object(); } + PyTypeObject *tp = Py_TYPE(obj); if (_is_basic_python_type(tp)) { return py::object(); } @@ -704,8 +720,11 @@ static py::object PyTorch_LookupSpecial(PyObject *obj, char* name) */ static auto check_has_torch_function(PyObject* obj) -> bool { + if (!torch_function_enabled()) { + return false; + } py::object method = PyTorch_LookupSpecial(obj, "__torch_function__"); - if(method.ptr() != nullptr){ + if(method.ptr() != nullptr && method.ptr() != disabled_torch_function_impl()){ return true; } return false; @@ -751,7 +770,19 @@ static auto check_has_torch_function(PyObject* obj) -> bool * 'torch_api' is a reference to a python torch API namespace. * */ +// Used for Tensor methods with arguments. +auto handle_torch_function(PythonArgs &r, PyObject* self, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject*; +// Used fpr functions. auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyObject* torch_api, const char* module_name) -> PyObject*; +// Used for functions that accept no keyword arguments and have no argument parsing +auto handle_torch_function(PyObject* self, const std::string& func_name, PyObject* args=nullptr, PyObject* torch_api=THPVariableClass, const std::string& module_name="torch.Tensor") -> PyObject*; + +// Used for getters of Tensor properties +auto handle_torch_function_getter(THPVariable* self, const std::string& property_name) -> PyObject*; + +// Used for setters of Tensor properties. +auto handle_torch_function_setter(THPVariable* self, const std::string& property_name, PyObject* value) -> int; + } // namespace torch diff --git a/torch/functional.py b/torch/functional.py index 2ec28e6e3e08a..3212c7d738655 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F from ._lowrank import svd_lowrank, pca_lowrank -from ._overrides import has_torch_function, handle_torch_function +from .overrides import has_torch_function, handle_torch_function from ._jit_internal import boolean_dispatch, List from ._jit_internal import _overload as overload @@ -319,7 +319,6 @@ def einsum(equation, *operands): if not torch.jit.is_scripting(): if any(type(t) is not Tensor for t in operands) and has_torch_function(operands): return handle_torch_function(einsum, operands, equation, *operands) - if len(operands) == 1 and isinstance(operands[0], (list, tuple)): # the old interface of passing the operands as one list argument operands = operands[0] diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 1971d9c40009f..5d4cf00af2e2e 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -12,7 +12,7 @@ from . import grad # noqa: F401 from torch import _VF from .._jit_internal import boolean_dispatch, List, Optional, _overload -from .._overrides import has_torch_function, handle_torch_function +from ..overrides import has_torch_function, handle_torch_function Tensor = torch.Tensor diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index e16bec0163b96..9749a70d024b2 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -1,4 +1,5 @@ import torch +from torch._C import _disabled_torch_function_impl from collections import OrderedDict @@ -19,7 +20,6 @@ class Parameter(torch.Tensor): requires_grad (bool, optional): if the parameter requires gradient. See :ref:`excluding-subgraphs` for more details. Default: `True` """ - def __new__(cls, data=None, requires_grad=True): if data is None: data = torch.Tensor() @@ -42,3 +42,5 @@ def __reduce_ex__(self, proto): torch._utils._rebuild_parameter, (self.data, self.requires_grad, OrderedDict()) ) + + __torch_function__ = _disabled_torch_function_impl diff --git a/torch/_overrides.py b/torch/overrides.py similarity index 79% rename from torch/_overrides.py rename to torch/overrides.py index 026a7f2385ef3..1671f97c1387d 100644 --- a/torch/_overrides.py +++ b/torch/overrides.py @@ -21,20 +21,25 @@ import __future__ import collections -import torch +import functools import types +from typing import Dict, Set, List, Any, Callable, Iterable + +import torch +from torch._C import _is_torch_function_enabled, _disabled_torch_function_impl -def get_ignored_functions(): - """Return public functions that cannot be overrided by __torch_function__ +@functools.lru_cache(None) +def get_ignored_functions() -> Set[Callable]: + """Return public functions that cannot be overridden by __torch_function__ Returns ------- A tuple of functions that are publicly available in the torch API but cannot - be overrided with __torch_function__. Mostly this is because none of the + be overridden with __torch_function__. Mostly this is because none of the arguments of these functions are tensors or tensor-likes. - """ - return ( + Tensor = torch.Tensor + return { torch.typename, torch.is_tensor, torch.is_storage, @@ -150,10 +155,33 @@ def get_ignored_functions(): torch.is_vulkan_available, torch.is_deterministic, torch.set_deterministic, - torch.unify_type_list - ) + torch.unify_type_list, + Tensor.__delitem__, + Tensor.__dir__, + Tensor.__getattribute__, + Tensor.__init__, + Tensor.__init_subclass__, + Tensor.__delattr__, + Tensor.__setattr__, + Tensor.__torch_function__, + Tensor.__new__, + Tensor.__class__, + Tensor.__subclasshook__, + Tensor.as_subclass, + Tensor.reinforce, + Tensor.new, + Tensor.new_tensor, + Tensor.new_empty, + Tensor.new_zeros, + Tensor.new_ones, + Tensor.new_full, + Tensor._make_subclass, + Tensor.stride, + Tensor.unflatten, + } -def get_testing_overrides(): +@functools.lru_cache(None) +def get_testing_overrides() -> Dict[Callable, Callable]: """Return a dict containing dummy overrides for all overridable functions Returns @@ -162,7 +190,6 @@ def get_testing_overrides(): lambda functions that have the same signature as the real function and unconditionally return -1. These lambda functions are useful for testing API coverage for a type that defines __torch_function__. - """ # Every function in the PyTorch API that can be overriden needs an entry # in this dict. @@ -171,7 +198,8 @@ def get_testing_overrides(): # the lambda function procedurally but that is blocked by generating # function signatures for native kernels that can be consumed by inspect. # See Issue #28233. - return { + Tensor = torch.Tensor + ret = { torch.abs: lambda input, out=None: -1, torch.absolute: lambda input, out=None: -1, torch.adaptive_avg_pool1d: lambda input, output_size: -1, @@ -688,9 +716,203 @@ def get_testing_overrides(): torch.var_mean: lambda input, dim=None: -1, torch.where: lambda condition, x=None, y=None: -1, torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, + Tensor.__floordiv__: lambda self, other: -1, + Tensor.__rfloordiv__: lambda self, other: -1, + Tensor.__ifloordiv__: lambda self, other: -1, + Tensor.__truediv__: lambda self, other: -1, + Tensor.__rtruediv__: lambda self, other: -1, + Tensor.__itruediv__: lambda self, other: -1, + Tensor.__lshift__: lambda self, other: -1, + Tensor.__ilshift__: lambda self, other: -1, + Tensor.__rshift__: lambda self, other: -1, + Tensor.__irshift__: lambda self, other: -1, + Tensor.__float__: lambda self: -1, + Tensor.__array__: lambda self, dtype: -1, + Tensor.__bool__: lambda self: -1, + Tensor.__contains__: lambda self, other: -1, + Tensor.__neg__: lambda self: -1, + Tensor.__invert__: lambda self: -1, + Tensor.__mod__: lambda self, other: -1, + Tensor.__array_wrap__: lambda self, array: -1, + Tensor.__getitem__: lambda self, idx: -1, + Tensor.__deepcopy__: lambda self, memo: -1, + Tensor.__iter__: lambda self: -1, + Tensor.__int__: lambda self: -1, + Tensor.__long__: lambda self: -1, + Tensor.__hash__: lambda self: -1, + Tensor.__index__: lambda self: -1, + Tensor.__len__: lambda self: -1, + Tensor.__format__: lambda self, format_spec: -1, + Tensor.__reduce_ex__: lambda self, proto: -1, + Tensor.__reversed__: lambda self: -1, + Tensor.__repr__: lambda self: -1, + Tensor.__setitem__: lambda self, k, v: -1, + Tensor.__setstate__: lambda self, d: -1, + Tensor.T.__get__: lambda self: -1, + Tensor._backward_hooks.__get__: lambda self: -1, + Tensor._base.__get__: lambda self: -1, + Tensor._cdata.__get__: lambda self: -1, + Tensor.grad.__get__: lambda self: -1, + Tensor._grad.__get__: lambda self: -1, + Tensor._grad_fn.__get__: lambda self: -1, + Tensor.grad_fn.__get__: lambda self: -1, + Tensor._version.__get__: lambda self: -1, + Tensor.data.__get__: lambda self: -1, + Tensor.device.__get__: lambda self: -1, + Tensor.dtype.__get__: lambda self: -1, + Tensor.is_cuda.__get__: lambda self: -1, + Tensor.is_leaf.__get__: lambda self: -1, + Tensor.is_meta.__get__: lambda self: -1, + Tensor.is_mkldnn.__get__: lambda self: -1, + Tensor.is_quantized.__get__: lambda self: -1, + Tensor.is_sparse.__get__: lambda self: -1, + Tensor.layout.__get__: lambda self: -1, + Tensor.name.__get__: lambda self: -1, + Tensor.names.__get__: lambda self: -1, + Tensor.ndim.__get__: lambda self: -1, + Tensor.output_nr.__get__: lambda self: -1, + Tensor.requires_grad.__get__: lambda self: -1, + Tensor.shape.__get__: lambda self: -1, + Tensor.volatile.__get__: lambda self: -1, + Tensor.real.__get__: lambda self: -1, + Tensor.imag.__get__: lambda self: -1, + Tensor.__cuda_array_interface__.__get__: lambda self: -1, + Tensor.type: lambda self, dtype=None, non_blocking=False, **kwargs: -1, + Tensor._coalesced_: lambda self: -1, + Tensor._dimI: lambda self: -1, + Tensor._dimV: lambda self: -1, + Tensor._indices: lambda self: -1, + Tensor._is_view: lambda self: -1, + Tensor._nnz: lambda self: -1, + Tensor._update_names: lambda self, names, inplace: -1, + Tensor._values: lambda self: -1, + Tensor.align_as: lambda self, other: -1, + Tensor.align_to: lambda self, order, ellipsis_idx: -1, + Tensor.apply_: lambda self, callable: -1, + Tensor.as_strided: lambda self, size, stride: -1, + Tensor.as_strided_: lambda self, size, stride: -1, + Tensor.backward: lambda self, gradient=None, retain_graph=None, create_graph=False: -1, + Tensor.bfloat16: lambda self, memory_format=torch.preserve_format: -1, + Tensor.bool: lambda self, memory_format=torch.preserve_format: -1, + Tensor.byte: lambda self, memory_format=torch.preserve_format: -1, + Tensor.char: lambda self, memory_format=torch.preserve_format: -1, + Tensor.cauchy_: lambda self, median=0, sigma=1, *, generator=None: -1, + Tensor.coalesce: lambda self: -1, + Tensor._coalesced_: lambda self, coalesced: -1, + Tensor.contiguous: lambda self, memory_format=torch.contiguous_format: -1, + Tensor.copy_: lambda self, src, non_blocking=False: -1, + Tensor.cpu: lambda self, memory_format=torch.preserve_format: -1, + Tensor.cuda: lambda self, memory_format=torch.preserve_format: -1, + Tensor.data_ptr: lambda self: -1, + Tensor.dense_dim: lambda self: -1, + Tensor.dim: lambda self: -1, + Tensor.double: lambda self, memory_format=torch.preserve_format: -1, + Tensor.element_size: lambda self: -1, + Tensor.expand: lambda self, size: -1, + Tensor.expand_as: lambda self, other: -1, + Tensor.exponential_: lambda self, lambd=1, *, generator=None: -1, + Tensor.fill_: lambda self, value: -1, + Tensor.fill_diagonal_: lambda self, value: -1, + Tensor.float: lambda self, memory_format=torch.preserve_format: -1, + Tensor.geometric_: lambda self, p, *, generator=None: -1, + Tensor.get_device: lambda self: -1, + Tensor.half: lambda self, memory_format=torch.preserve_format: -1, + Tensor.has_names: lambda self: -1, + Tensor.indices: lambda self: -1, + Tensor.int: lambda self, memory_format=torch.preserve_format: -1, + Tensor.is_coalesced: lambda self: -1, + Tensor.is_contiguous: lambda self: -1, + Tensor.is_pinned: lambda self: -1, + Tensor.is_set_to: lambda self, tensor: -1, + Tensor.is_shared: lambda self: -1, + Tensor.item: lambda self: -1, + Tensor.log_normal_: lambda self, mean=1, std=2, *, generator=None: -1, + Tensor.log_softmax: lambda self, dim: -1, + Tensor.long: lambda self, memory_format=torch.preserve_format: -1, + Tensor.map_: lambda self, tensor, callable: -1, + Tensor.map2_: lambda self, x, y, callable: -1, + Tensor.mm: lambda self, mat2: -1, + Tensor.narrow_copy: lambda self, dimension, start, length: -1, + Tensor.ndimension: lambda self: -1, + Tensor.nelement: lambda self: -1, + Tensor.normal_: lambda self: -1, + Tensor.numpy: lambda self: -1, + Tensor.permute: lambda self, dim: -1, + Tensor.pin_memory: lambda self: -1, + Tensor.put_: lambda self, indices, tensor, accumulate=False: -1, + Tensor.qscheme: lambda self: -1, + Tensor.random_: lambda self, from_=0, to=None, *, generator=None: -1, + Tensor.record_stream: lambda self, stream: -1, + Tensor.refine_names: lambda self, names: -1, + Tensor.register_hook: lambda self, hook: -1, + Tensor.rename: lambda self, name: -1, + Tensor.repeat: lambda self, *size: -1, + Tensor.requires_grad_: lambda self, requires_grad=True: -1, + Tensor.reshape_as: lambda self, other: -1, + Tensor.resize: lambda self, *size: -1, + Tensor.resize_: lambda self, size: -1, + Tensor.resize_as: lambda self, other: -1, + Tensor.retain_grad: lambda self: -1, + Tensor.set_: lambda self, source=None, storage_offset=0, size=None, stride=None: -1, + Tensor.share_memory_: lambda self: -1, + Tensor.short: lambda self, memory_format=torch.preserve_format: -1, + Tensor.size: lambda self: -1, + Tensor.sparse_dim: lambda self: -1, + Tensor.sparse_mask: lambda self, mask: -1, + Tensor.sparse_resize_: lambda self, size1, size2, dense_dim: -1, + Tensor.sparse_resize_and_clear_: lambda self, size1, size2, dense_dim: -1, + Tensor.sspaddmm: lambda self, mat1, mat2, beta=1, alpha=1, out=None: -1, + Tensor.storage: lambda self: -1, + Tensor.storage_offset: lambda self: -1, + Tensor.storage_type: lambda self: -1, + Tensor.sum_to_size: lambda self, size: -1, + Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1, + Tensor.to_dense: lambda self: -1, + Tensor.to_sparse: lambda self: -1, + Tensor.tolist: lambda self: -1, + Tensor.to_mkldnn: lambda self: -1, + Tensor.type_as: lambda self, other: -1, + Tensor.unfold: lambda self, dimension, size, step: -1, + Tensor.uniform_: lambda self, from_=0, to=1: -1, + Tensor.values: lambda self: -1, + Tensor.view: lambda self, shape: -1, + Tensor.view_as: lambda self, other: -1, + Tensor.zero_: lambda self: -1, } -def _get_overloaded_args(relevant_args): + ret2 = {} + ignored = get_ignored_functions() + + for k, v in ret.items(): + # Generate methods like __add__ and add_ by default from add + names = [ + k.__name__, # Default method + k.__name__ + "_", # Inplace variant + "__" + k.__name__ + "__", # Dunder method + "__i" + k.__name__ + "__", # Inplace dunder method + "__r" + k.__name__ + "__", # Reverse dunder method + ] + + if k.__name__.startswith("bitwise_"): + # bitwise_ have dunder methods of the form ____ + # And so on. + subname = k.__name__[len("bitwise_"):] + names.extend([ + "__" + subname + "__", + "__i" + subname + "__", + "__r" + subname + "__" + ]) + + for name in names: + func = getattr(Tensor, name, None) + if callable(func) and func not in ret and func not in ignored: + ret2[func] = v + + ret.update(ret2) + return ret + + +def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: """Returns a list of arguments on which to call __torch_function__. Checks arguments in relevant_args for __torch_function__ implementations, @@ -754,7 +976,7 @@ def _get_overloaded_args(relevant_args): def handle_torch_function( - public_api, relevant_args, *args, **kwargs): + public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any: """Implement a function with checks for __torch_function__ overrides. See torch::autograd::handle_torch_function for the equivalent of this @@ -802,7 +1024,7 @@ def handle_torch_function( '__torch_function__: {}' .format(func_name, list(map(type, overloaded_args)))) -def has_torch_function(relevant_args): +def has_torch_function(relevant_args: Iterable[Any]) -> bool: """Check for __torch_function__ implementations in the elements of an iterable Arguments @@ -815,41 +1037,64 @@ def has_torch_function(relevant_args): True if any of the elements of relevant_args have __torch_function__ implementations, False otherwise. """ - return any(hasattr(a, '__torch_function__') for a in relevant_args) + return _is_torch_function_enabled() and any( + type(a) is not torch.Tensor and + getattr(a, '__torch_function__', _disabled_torch_function_impl) + is not _disabled_torch_function_impl + for a in relevant_args + ) -def get_overridable_functions(): +@functools.lru_cache(None) +def get_overridable_functions() -> Dict[Any, List[Callable]]: """List functions that are overridable via __torch_function__ Returns ------- A dictionary that maps namespaces that contain overridable functions - to functions in that namespace that can be overrided. - + to functions in that namespace that can be overridden. """ overridable_funcs = collections.defaultdict(list) tested_namespaces = [ (torch, torch.__all__ + dir(torch._C._VariableFunctions)), (torch.functional, torch.functional.__all__), (torch.nn.functional, dir(torch.nn.functional)), + (torch.Tensor, dir(torch.Tensor)) ] for namespace, ns_funcs in tested_namespaces: for func_name in ns_funcs: # ignore private functions or functions that are deleted in torch.__init__ - if func_name.startswith('_') or func_name == 'unique_dim': - continue - # ignore in-place operators - if func_name.endswith('_'): - continue - # only consider objects with lowercase names - if not func_name.islower(): - continue + if namespace is not torch.Tensor: + if func_name.startswith('_'): + continue + elif func_name.endswith('_'): + continue + elif not func_name[0].islower(): + continue + elif func_name == 'unique_dim': + continue + else: + func = getattr(namespace, func_name) + if getattr(object, func_name, None) == func: + continue + if func_name == '__weakref__': + continue func = getattr(namespace, func_name) + if namespace is torch.Tensor and getattr(object, func_name, None) == func: + continue # ignore re-exported modules if isinstance(func, types.ModuleType): continue # ignore __future__ imports if isinstance(func, __future__._Feature): continue + + if not callable(func) and hasattr(func, "__get__"): + overridable_funcs[func].append(func.__get__) + continue + + if not callable(func): + continue + # cannot be overriden by __torch_function__ if func in get_ignored_functions(): msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions " @@ -858,3 +1103,27 @@ def get_overridable_functions(): continue overridable_funcs[namespace].append(func) return overridable_funcs + +@functools.lru_cache(None) +def get_tensor_methods() -> Set[Callable]: + """ Returns a set of the overridable methods on ``torch.Tensor`` """ + overridable_funcs = get_overridable_functions() + methods = set(overridable_funcs[torch.Tensor]) + return methods + +def is_tensor_method_or_property(func: Callable) -> bool: + """ + Returns True if the function passed in is a handler for a + method or property belonging to ``torch.Tensor``, as passed + into ``__torch_function__``. + + .. note:: + For properties, their ``__get__`` method must be passed in. + + This may be needed, in particular, for the following reasons: + + 1. Methods/properties sometimes don't contain a `__module__` slot. + 2. They require that the first passed-in argument is an instance + of ``torch.Tensor``. + """ + return func in get_tensor_methods() or func.__name__ == "__get__" diff --git a/torch/tensor.py b/torch/tensor.py index d14860b372d8a..253ab15b55e1b 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -18,6 +18,9 @@ def _wrap_type_error_to_not_implemented(f): @functools.wraps(f, assigned=assigned) def wrapped(*args, **kwargs): + from torch.overrides import has_torch_function, handle_torch_function + if not all(type(t) is Tensor for t in args) and has_torch_function(args): + return handle_torch_function(wrapped, args, *args, **kwargs) try: return f(*args, **kwargs) except TypeError: @@ -34,6 +37,10 @@ def wrapped(*args, **kwargs): # otherwise, it will not show up in autocomplete. class Tensor(torch._C._TensorBase): def __deepcopy__(self, memo): + from torch.overrides import has_torch_function, handle_torch_function + relevant_args = (self,) + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__deepcopy__, relevant_args, self, memo) if not self.is_leaf: raise RuntimeError("Only Tensors created explicitly by the user " "(graph leaves) support the deepcopy protocol at the moment") @@ -70,6 +77,10 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__reduce_ex__, relevant_args, self, proto) check_serializing_named_tensor(self) # See Note [Don't serialize hooks] torch.utils.hooks.warn_if_has_hooks(self) @@ -132,6 +143,10 @@ def __reduce_ex__(self, proto): return (torch._utils._rebuild_tensor_v2, args) def __setstate__(self, state): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__setstate__, relevant_args, self, state) # Warning: this method is NOT called when you torch.load() a tensor; # that is managed by _rebuild_tensor_v2 if not self.is_leaf: @@ -149,6 +164,10 @@ def __setstate__(self, state): self.requires_grad, _, self._backward_hooks = state def __repr__(self): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__repr__, relevant_args, self) # All strings are unicode in Python 3. return torch._tensor_str._str(self) @@ -182,6 +201,16 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False): be constructed, allowing to compute higher order derivative products. Defaults to ``False``. """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.backward, + relevant_args, + self, + gradient=gradient, + retain_graph=retain_graph, + create_graph=create_graph) torch.autograd.backward(self, gradient, retain_graph, create_graph) def register_hook(self, hook): @@ -213,6 +242,10 @@ def register_hook(self, hook): >>> h.remove() # removes the hook """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.register_hook, relevant_args, self, hook) if not self.requires_grad: raise RuntimeError("cannot register a hook on a tensor that " "doesn't require gradient") @@ -278,6 +311,10 @@ def trim(str): def retain_grad(self): r"""Enables .grad attribute for non-leaf Tensors.""" + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.retain_grad, relevant_args, self) if not self.requires_grad: raise RuntimeError("can't retain_grad on Tensor that has requires_grad=False") if self.is_leaf: # no-op for leaves @@ -306,6 +343,10 @@ def is_shared(self): This is always ``True`` for CUDA tensors. """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.is_shared, relevant_args, self) return self.storage().is_shared() def share_memory_(self): @@ -314,11 +355,19 @@ def share_memory_(self): This is a no-op if the underlying storage is already in shared memory and for CUDA tensors. Tensors in shared memory cannot be resized. """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.share_memory_, relevant_args, self) self.storage().share_memory_() return self def __reversed__(self): r"""Reverses the tensor along dimension 0.""" + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__reversed__, relevant_args, self) if self.dim() == 0: return self else: @@ -326,11 +375,19 @@ def __reversed__(self): def norm(self, p="fro", dim=None, keepdim=False, dtype=None): r"""See :func:`torch.norm`""" + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.norm, relevant_args, self, p=p, dim=dim, keepdim=keepdim, dtype=dtype) return torch.norm(self, p, dim, keepdim, dtype=dtype) def lu(self, pivot=True, get_infos=False): r"""See :func:`torch.lu`""" # If get_infos is True, then we don't need to check for errors and vice versa + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.lu, relevant_args, self, pivot=pivot, get_infos=get_infos) LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos)) if get_infos: return LU, pivots, infos @@ -345,21 +402,44 @@ def stft(self, n_fft, hop_length=None, win_length=None, window=None, This function changed signature at version 0.4.1. Calling with the previous signature may cause error or return incorrect result. """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.stft, relevant_args, self, n_fft, hop_length=hop_length, + win_length=win_length, window=window, center=center, pad_mode=pad_mode, normalized=normalized, + onesided=onesided + ) return torch.stft(self, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided) def istft(self, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=True, length=None): r"""See :func:`torch.istft`""" + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.istft, relevant_args, self, n_fft, hop_length=hop_length, win_length=win_length, + window=window, center=center, normalized=normalized, onesided=onesided, length=None + ) return torch.istft(self, n_fft, hop_length, win_length, window, center, normalized, onesided, length) def resize(self, *sizes): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.resize, relevant_args, self, *sizes) warnings.warn("non-inplace resize is deprecated") from torch.autograd._functions import Resize return Resize.apply(self, sizes) def resize_as(self, tensor): + relevant_args = (self, tensor) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and type(tensor) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.resize_as, relevant_args, self, tensor) warnings.warn("non-inplace resize_as is deprecated") from torch.autograd._functions import Resize return Resize.apply(self, tensor.size()) @@ -367,6 +447,10 @@ def resize_as(self, tensor): def split(self, split_size, dim=0): r"""See :func:`torch.split` """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.split, relevant_args, self, split_size, dim=dim) if isinstance(split_size, int): return super(Tensor, self).split(split_size, dim) elif isinstance(split_size, Tensor): @@ -383,6 +467,13 @@ def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=Non See :func:`torch.unique` """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.unique, relevant_args, self, sorted=sorted, return_inverse=return_inverse, + return_counts=return_counts, dim=dim + ) return torch.unique(self, sorted=sorted, return_inverse=return_inverse, return_counts=return_counts, dim=dim) def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None): @@ -390,12 +481,27 @@ def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None See :func:`torch.unique_consecutive` """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function( + Tensor.unique_consecutive, relevant_args, self, return_inverse=return_inverse, + return_counts=return_counts, dim=dim + ) return torch.unique_consecutive(self, return_inverse=return_inverse, return_counts=return_counts, dim=dim) def __rsub__(self, other): + relevant_args = (self, other) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__rsub__, relevant_args, self, other) return _C._VariableFunctions.rsub(self, other) def __rdiv__(self, other): + relevant_args = (self, other) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__rdiv__, relevant_args, self, other) if self.dtype.is_floating_point or self.dtype.is_complex: return self.reciprocal() * other else: @@ -407,11 +513,19 @@ def __rdiv__(self, other): __pow__ = _C._TensorBase.pow def __format__(self, format_spec): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__format__, relevant_args, self, format_spec) if self.dim() == 0: return self.item().__format__(format_spec) return object.__format__(self, format_spec) def __ipow__(self, other): + relevant_args = (self, other) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__ipow__, relevant_args, self, other) return NotImplemented @_wrap_type_error_to_not_implemented @@ -441,11 +555,25 @@ def __rfloordiv__(self, other): __abs__ = _C._TensorBase.abs def __len__(self): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__len__, relevant_args, self) if self.dim() == 0: raise TypeError("len() of a 0-d tensor") return self.shape[0] def __iter__(self): + # NB: we use 'imap' and not 'map' here, so that in Python 2 we get a + # generator and don't eagerly perform all the indexes. This could + # save us work, and also helps keep trace ordering deterministic + # (e.g., if you zip(*hiddens), the eager map will force all the + # indexes of hiddens[0] before hiddens[1], while the generator + # map will interleave them.) + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__iter__, relevant_args, self) if self.dim() == 0: raise TypeError('iteration over a 0-d tensor') if torch._C._get_tracing_state(): @@ -456,9 +584,17 @@ def __iter__(self): return iter(self.unbind(0)) def __hash__(self): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__hash__, relevant_args, self) return id(self) def __dir__(self): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__dir__, relevant_args, self) if self.is_quantized: warnings.warn('Only a small subset of methods are supported for quantized tensors.') tensor_methods = dir(self.__class__) @@ -476,6 +612,10 @@ def __dir__(self): __array_priority__ = 1000 # prefer Tensor ops over numpy ones def __array__(self, dtype=None): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__array__, relevant_args, self, dtype=dtype) if dtype is None: return self.numpy() else: @@ -484,6 +624,10 @@ def __array__(self, dtype=None): # Wrap Numpy array again in a suitable tensor when done, to support e.g. # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor` def __array_wrap__(self, array): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__array_wrap__, relevant_args, self, array=array) if array.dtype == bool: # Workaround, torch has no built-in bool tensor array = array.astype('uint8') @@ -496,6 +640,10 @@ def __contains__(self, element): element (Tensor or scalar): element to be checked for presence in current tensor" """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__contains__, relevant_args, self, element) if isinstance(element, (torch.Tensor, Number)): return (element == self).any().item() @@ -511,6 +659,10 @@ def __cuda_array_interface__(self): See: https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self) # raise AttributeError for unsupported tensors, so that # hasattr(cpu_tensor, "__cuda_array_interface__") is False. @@ -601,6 +753,10 @@ def refine_names(self, *names): The named tensor API is experimental and subject to change. """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.refine_names, relevant_args, self, *names) names = resolve_ellipsis(names, self.names, 'refine_names') return super(Tensor, self).refine_names(names) @@ -640,6 +796,10 @@ def align_to(self, *names): The named tensor API is experimental and subject to change. """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.align_to, relevant_args, self, *names) ellipsis_idx = single_ellipsis_index(names, 'align_to') if ellipsis_idx is None: return super(Tensor, self).align_to(names) @@ -665,12 +825,21 @@ def unflatten(self, dim, namedshape): The named tensor API is experimental and subject to change. """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.unflatten, relevant_args, self, dim, namedshape) names, sizes = unzip_namedshape(namedshape) return super(Tensor, self).unflatten(dim, sizes, names) def rename_(self, *names, **rename_map): """In-place version of :meth:`~Tensor.rename`.""" + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.rename_, relevant_args, self, *names, **rename_map) + # Note [rename_ / rename API] # The Python API for these is different from the C++ API. In Python: # 1) tensor.rename(*names) takes a vararglist of names @@ -712,10 +881,20 @@ def rename(self, *names, **rename_map): The named tensor API is experimental and subject to change. """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.rename, relevant_args, self, *names, **rename_map) + # See Note [rename_ / rename API] return update_names(self, names, rename_map, inplace=False) def _update_names(self, names, inplace): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor._update_names, relevant_args, self, names, inplace) + # See Note [rename_ / rename API] if inplace: return super(Tensor, self).rename_(names) @@ -730,6 +909,11 @@ def grad(self): The attribute will then contain the gradients computed and future calls to :func:`backward` will accumulate (add) gradients into it. """ + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.grad.__get__, relevant_args, self) + if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None: warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad " "attribute won't be populated during autograd.backward(). If you indeed want the gradient " @@ -740,10 +924,56 @@ def grad(self): @grad.setter def grad(self, new_grad): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad) self._grad = new_grad @grad.deleter def grad(self): + relevant_args = (self,) + from torch.overrides import has_torch_function, handle_torch_function + if type(self) is not Tensor and has_torch_function(relevant_args): + return handle_torch_function(Tensor.grad.__delete__, relevant_args, self) del self._grad + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """ + This __torch_function__ implementation wraps subclasses such that + methods called on subclasses return a subclass instance instead of + a ``torch.Tensor`` instance. + + One corollary to this is that you need coverage for torch.Tensor + methods if implementing __torch_function__ for subclasses. + + We recommend always calling ``super().__torch_function__`` as the base + case when doing the above. + + While not mandatory, we recommend making `__torch_function__` a classmethod. + """ + if kwargs is None: + kwargs = {} + + if not all(issubclass(cls, t) for t in types): + return NotImplemented + + with _C.DisableTorchFunction(): + ret = func(*args, **kwargs) + return _convert(ret, cls) + __module__ = 'torch' + + +def _convert(ret, cls): + if cls is Tensor: + return ret + + if isinstance(ret, Tensor): + ret = ret.as_subclass(cls) + + if isinstance(ret, tuple): + ret = tuple(_convert(r, cls) for r in ret) + + return ret