topk.py
# A differentiable top-k that I am temporarily borrowing from
# https://math.stackexchange.com/questions/3280757/differentiable-top-k-function
import torch
from functorch import vmap, grad  # on PyTorch >= 2.0 these also live in torch.func
from torch.autograd import Function

sigmoid = torch.sigmoid
sigmoid_grad = vmap(vmap(grad(sigmoid)))  # elementwise sigmoid'(x) for a 2-D tensor
class TopK(Function):
    @staticmethod
    def forward(ctx, xs, k):
        ts, ps = _find_ts(xs, k)
        ctx.save_for_backward(xs, ts)
        return ps

    @staticmethod
    def backward(ctx, grad_output):
        # Compute the vjp, that is grad_output.T @ J.
        xs, ts = ctx.saved_tensors
        # Let v = sigmoid'(x + t)
        v = sigmoid_grad(xs + ts)
        s = v.sum(dim=1, keepdim=True)
        # The Jacobian is J = diag(v) - v v^T / s (derivation below).
        uv = grad_output * v
        t1 = -uv.sum(dim=1, keepdim=True) * v / s
        return t1 + uv, None
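
# Derivation of the backward pass (explanatory note, not part of the borrowed code):
# the forward returns p_i = sigmoid(x_i + t(x)), where the shift t(x) is defined
# implicitly by the constraint sum_i sigmoid(x_i + t) = k.  With v_i = sigmoid'(x_i + t)
# and s = sum_i v_i, differentiating the constraint gives dt/dx_j = -v_j / s, hence
#     dp_i/dx_j = v_i * (delta_ij + dt/dx_j) = delta_ij * v_i - v_i * v_j / s,
# i.e. J = diag(v) - v v^T / s, which backward applies to grad_output as `t1 + uv`.
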
@torch.no_grad()
def _find_ts(xs, k):
    """Bisect for the per-row shift ts such that sigmoid(xs + ts) sums to k."""
    b, n = xs.shape
    assert 0 < k < n
    # lo should be small enough that all sigmoids are in their 0 region;
    # similarly, hi should be large enough that all are in their 1 region.
    lo = -xs.max(dim=1, keepdim=True).values - 10
    hi = -xs.min(dim=1, keepdim=True).values + 10
    for _ in range(64):
        mid = (hi + lo) / 2
        mask = sigmoid(xs + mid).sum(dim=1) < k
        lo[mask] = mid[mask]
        hi[~mask] = mid[~mask]
    ts = (lo + hi) / 2
    return ts, sigmoid(xs + ts)
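
# Note on the bisection above (explanatory, not in the original source): the row-wise
# sum f(t) = sum_i sigmoid(x_i + t) increases strictly from ~0 at t = lo to ~n at
# t = hi, so f(t) = k has a unique root whenever 0 < k < n.  Each of the 64 iterations
# halves the [lo, hi] bracket, shrinking its width by a factor of 2^64 -- well below
# floating-point resolution, so ts is as accurate as the dtype allows.
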
topk = TopK.apply

if __name__ == "__main__":
    # Quick smoke test: each row of ps should sum to (approximately) k.
    xs = torch.randn(2, 3)
    ps = topk(xs, 2)
    print(xs, ps, ps.sum(dim=1))

    # gradcheck compares the analytic backward against finite differences.
    from torch.autograd import gradcheck
    inputs = torch.randn(20, 10, dtype=torch.double, requires_grad=True)
    for k in range(1, 10):
        print(k, gradcheck(topk, (inputs, k), eps=1e-6, atol=1e-4))
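
    # Minimal usage sketch (an added illustration, not part of the borrowed code):
    # treat the soft top-k probabilities as a differentiable mask.  The tensors
    # `scores` and `values` below are made up purely for demonstration.
    scores = torch.randn(4, 8, requires_grad=True)
    values = torch.randn(4, 8)
    soft_mask = topk(scores, 3)                      # rows sum to ~3, entries in (0, 1)
    soft_top3_sum = (soft_mask * values).sum(dim=1)  # relaxed "sum of the top-3 values"
    soft_top3_sum.sum().backward()                   # gradients flow back into scores
    print(scores.grad.shape)                         # torch.Size([4, 8])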