# inner_autograd_tensor.py
import torch
import torch.nn.functional
from torch.overrides import enable_reentrant_dispatch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._pytree import tree_map

from base_tensor import BaseTensor
from utils import fill_defaults

# This file describes how to use wrapper tensors (a la TrivialTensorViaComposition)
# to override autograd from __torch_dispatch__.  Ordinarily,
# __torch_dispatch__ runs after autograd, so you have no way of overriding
# the autograd behavior there (it has already been handled by the time you
# return).  However, if we put the autograd tensor *inside* a wrapper tensor
# (which doesn't itself require gradients), we get a chance to interpose
# (in __torch_dispatch__) before autograd handles gradients on the inner
# element.
#
# Note that you can also use __torch_function__ instead to implement this
# functionality, so this is mostly a question of whether you want to
# target the public Python API or the internal ATen operator API
# (torch.ops.aten).


class InnerAutogradTensor(BaseTensor):
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        # The outer tensor's autograd is now disconnected from the inner
        # tensor's autograd...
        return super().__new__(cls, elem, requires_grad=False)

    def __init__(self, elem):
        # ... but note that we save the inner tensor, so we can still
        # do autograd on operations on the inside!
        self.elem = elem

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            if isinstance(t, cls):
                return t.elem
            elif isinstance(t, torch.Tensor) and t.requires_grad:
                # If any other argument at this level requires gradients,
                # it would not interact correctly with our inner tensor's
                # autograd, so this should fail.
                raise RuntimeError("Bad mixup of autograd level")
            else:
                return t

        def wrap(t):
            # Micro-optimization: not necessary to rewrap if the output tensor
            # doesn't require gradients
            if (
                isinstance(t, torch.Tensor)
                and not isinstance(t, cls)
                and t.requires_grad
            ):
                return cls(t)
            else:
                return t
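
        # enable_reentrant_dispatch lets the calls to func below re-enter the
        # full dispatcher, so autograd can record the operations on the
        # unwrapped inner tensors (which do require grad).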
        with enable_reentrant_dispatch():
            # Override gradient behavior
            if func == torch.ops.aten.embedding.default:
                # Pad the positional args out to the full aten.embedding
                # signature: (weight, indices, padding_idx=-1,
                # scale_grad_by_freq=False, sparse=False).
                args = fill_defaults(args, 5, [-1, False, False])
                weight, indices, padding_idx, scale_grad_by_freq, _sparse = map(
                    unwrap, args
                )
                assert not kwargs
                # Force sparse gradients.  We could have also done this by
                # defining a custom autograd function; see the SparseEmbedding
                # sketch below.
                return cls(
                    func(weight, indices, padding_idx, scale_grad_by_freq, True)
                )
            return tree_map(
                wrap,
                func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})),
            )
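

# A minimal sketch of the alternative mentioned in __torch_dispatch__ above:
# instead of re-calling aten.embedding with sparse=True, one could force
# sparse gradients by defining a custom autograd function.  This class is
# illustrative only (the name SparseEmbedding and the exact backward layout
# are our own) and is not used by the wrapper above; usage would look like
# SparseEmbedding.apply(weight, indices).
class SparseEmbedding(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weight, indices):
        ctx.save_for_backward(indices)
        ctx.num_weights = weight.size(0)
        # Plain advanced indexing implements the embedding lookup.
        return weight[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # Scatter the incoming gradient into a sparse COO tensor over the
        # weight, mirroring what aten.embedding does when sparse=True.
        grad_weight = torch.sparse_coo_tensor(
            indices.reshape(1, -1),
            grad_output.reshape(-1, grad_output.size(-1)),
            (ctx.num_weights, grad_output.size(-1)),
        )
        # No gradient for the integer indices.
        return grad_weight, None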


class InnerAutogradTensorTest(TestCase):
    def test_basic(self):
        x = torch.randn(1, requires_grad=True)
        y = InnerAutogradTensor(x)
        self.assertFalse(y.requires_grad)
        self.assertTrue(y.elem.requires_grad)
        z = InnerAutogradTensor(x)
        # Although y and z do not require grad, we are still able
        # to differentiate
        r = y + z
        # Note we have to extract out the inner tensor (which requires_grad)
        # to actually differentiate
        r.sum().elem.backward()
        self.assertEqual(x.grad, torch.tensor([2.0]))  # two uses!
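        # Extra sanity checks: the result of y + z is itself rewrapped, and
        # its inner tensor is the one that carries the autograd graph.
        self.assertIsInstance(r, InnerAutogradTensor)
        self.assertIsNotNone(r.elem.grad_fn)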

    def test_embedding(self):
        input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
        weights = torch.rand(10, 3, requires_grad=True)
        embedding_matrix = InnerAutogradTensor(weights)
        r = torch.nn.functional.embedding(input, embedding_matrix)
        r.sum().elem.backward()
        # Gradient is sparse even though we didn't ask for it in embedding!
        self.assertTrue(weights.grad.is_sparse)
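        # Sanity check: densifying the forced-sparse gradient matches what an
        # ordinary dense embedding backward produces for the same input.
        dense_weights = weights.detach().clone().requires_grad_(True)
        torch.nn.functional.embedding(input, dense_weights).sum().backward()
        self.assertEqual(weights.grad.to_dense(), dense_weights.grad)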

    def test_mixing(self):
        # Mixing behavior is confusing.  Let's take a look.
        w1 = torch.randn(1, requires_grad=True)
        w2 = torch.randn(1, requires_grad=True)
        # Autograd doesn't "unwrap" variables; they still remember whether
        # they require grad, and in fact, inside __torch_dispatch__ autograd
        # is willing to mix gradients between multiple levels.  The current
        # class catches most of these mixups, though, when it inspects the
        # arguments in unwrap.
        with self.assertRaisesRegex(RuntimeError, "Bad mixup of autograd level"):
            x = InnerAutogradTensor(w1) + w2
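
    def test_mixing_with_plain_tensor(self):
        # Companion check to test_mixing (this test is our own addition): a
        # plain tensor that does *not* require grad passes through unwrap
        # untouched, so mixing it with an InnerAutogradTensor works fine.
        w = torch.randn(1, requires_grad=True)
        plain = torch.randn(1)
        r = InnerAutogradTensor(w) + plain
        self.assertIsInstance(r, InnerAutogradTensor)
        r.sum().elem.backward()
        self.assertEqual(w.grad, torch.tensor([1.0]))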


if __name__ == "__main__":
    run_tests()