fast_attention.py
import math

import torch
from torch.nn import Linear, Module, ModuleList


def apply_scaling(scale, x):
    # Multiply each of the n rows of x by the matching entry of scale, broadcasting over leading batch dims.
    return torch.einsum("...n,...nd->...nd", scale, x)


def create_orf(d_k, m):
    # Build m orthogonal random features of dimension d_k: orthogonalise Gaussian blocks via QR,
    # then rescale each row so its norm matches that of a d_k-dimensional Gaussian vector.
    blocks = torch.randn(math.ceil(m / d_k), d_k, d_k)
    blocks, _ = torch.linalg.qr(blocks)
    scale = torch.randn(m, d_k).norm(dim=1)
    return apply_scaling(scale, blocks.reshape(-1, d_k)[:m])


def apply_regular_feature_map(x, orf, epsilon=1e-6):
    # Positive random-feature map: exp(x W^T / d_k^(1/4) - ||x||^2 / (2 sqrt(d_k))) / sqrt(m),
    # whose inner products approximate the softmax kernel exp(q . k / sqrt(d_k)).
    m, d_k = orf.shape
    proj_x = x @ orf.T / math.pow(d_k, 1 / 4)
    norm = (x ** 2).sum(dim=-1, keepdim=True) / (2 * math.sqrt(d_k))
    return (torch.exp(proj_x - norm) + epsilon) / math.sqrt(m)


def apply_hyperbolic_feature_map(x, orf, epsilon=1e-6):
    # Hyperbolic variant: use both exp(+proj) and exp(-proj), doubling the feature count to 2m.
    m, d_k = orf.shape
    proj_x = x @ orf.T / math.pow(d_k, 1 / 4)
    proj_x = torch.cat([proj_x, -proj_x], dim=-1)
    norm = (x ** 2).sum(dim=-1, keepdim=True) / (2 * math.sqrt(d_k))
    return (torch.exp(proj_x - norm) + epsilon) / math.sqrt(2 * m)
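

# --- Added illustration (not in the original file): the feature maps above are built so that
# --- phi(q) . phi(k) estimates the softmax kernel exp(q . k / sqrt(d_k)); this helper simply
# --- prints the estimate next to the exact value. The function name and the sizes below are
# --- illustrative assumptions, not part of the original module.
def _demo_softmax_kernel_estimate(d_k=16, m=4096):
    q, k = 0.5 * torch.randn(d_k), 0.5 * torch.randn(d_k)
    orf = create_orf(d_k, m)
    estimate = apply_regular_feature_map(q, orf) @ apply_regular_feature_map(k, orf)
    exact = torch.exp(q @ k / math.sqrt(d_k))
    print(float(estimate), float(exact))  # the two values should be close for large m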


def fast_attention(query, key, value):
    # Linear-time attention: compute K^T V and K^T 1 once, multiply by Q, then normalise,
    # avoiding the explicit L x L attention matrix.
    buffer = torch.cat([key.transpose(1, 2).bmm(value), key.sum(1).unsqueeze(-1)], dim=-1)
    buffer = query.bmm(buffer)
    return apply_scaling(1 / buffer[:, :, -1], buffer[:, :, :-1])
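

# --- Added illustration (not in the original file): a quick sanity check that the linear-time
# --- rearrangement above matches the explicit quadratic computation of the same kernelised
# --- attention, out = (Q' K'^T V) / (Q' K'^T 1), for non-negative features Q', K'.
# --- The function name and the shapes below are illustrative assumptions.
def _check_fast_attention_equivalence(n=2, length=16, m=32, d_v=8):
    query = torch.rand(n, length, m)    # non-negative, like the outputs of the feature maps
    key = torch.rand(n, length, m)
    value = torch.randn(n, length, d_v)
    scores = query.bmm(key.transpose(1, 2))                      # explicit (length x length) matrix
    expected = scores.bmm(value) / scores.sum(-1, keepdim=True)  # quadratic-time reference
    assert torch.allclose(fast_attention(query, key, value), expected, atol=1e-5)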


class FastSelfAttention(Module):
    def __init__(self, d_model, h, m, use_hyperbolic):
        super().__init__()
        self.h = h
        # Three input projections (query, key, value) plus one output projection.
        self.linears = ModuleList([Linear(d_model, d_model) for _ in range(4)])
        # Orthogonal random features shared across heads; excluded from the state dict.
        self.register_buffer("orf", create_orf(d_model // h, m), persistent=False)
        self.apply_feature_map = apply_regular_feature_map
        if use_hyperbolic:
            self.apply_feature_map = apply_hyperbolic_feature_map

    def redraw_orf(self):
        # Resample the orthogonal random features in place (e.g. periodically during training).
        m, d_k = self.orf.shape
        orf = create_orf(d_k, m)
        orf = orf.to(self.orf.device)
        self.register_buffer("orf", orf, persistent=False)

    def split_by_head(self, x, B, L):
        # (B, L, d_model) -> (B * h, L, d_model / h)
        return x.view(B, L, self.h, -1).permute(0, 2, 1, 3).reshape(B * self.h, L, -1)

    def concat_by_head(self, x, B, L):
        # (B * h, L, d_v) -> (B, L, h * d_v)
        return x.reshape(B, self.h, L, -1).permute(0, 2, 1, 3).reshape(B, L, -1)

    def forward(self, x):
        B, L, _ = x.shape
        query, key, value = (self.split_by_head(linear(x), B, L) for linear in self.linears[:3])
        # Map queries and keys to positive random features, then apply linear attention per head.
        query = self.apply_feature_map(query, self.orf)
        key = self.apply_feature_map(key, self.orf)
        out = fast_attention(query, key, value)
        out = self.concat_by_head(out, B, L)
        out = self.linears[3](out)
        return out
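
Below is a minimal usage sketch, not part of the original file. The import mirrors the file name, and the batch size, sequence length, and hyperparameters are arbitrary illustrative choices.

import torch

from fast_attention import FastSelfAttention

torch.manual_seed(0)
d_model, h, m = 64, 4, 128              # model width, heads, random features (illustrative values)
attention = FastSelfAttention(d_model, h, m, use_hyperbolic=True)
x = torch.randn(2, 100, d_model)        # (batch, sequence length, d_model)
print(attention(x).shape)               # torch.Size([2, 100, 64])
attention.redraw_orf()                  # resample the orthogonal random features
print(attention(x).shape)               # shape is unchanged after redrawing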