transformer.py
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

def split_last(x, shape):
    # Split the last dimension of x into `shape`; at most one entry may be -1,
    # in which case it is inferred from x.size(-1) and the remaining entries.
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        # np.prod(shape) still contains the -1 factor, so negate it to get the
        # product of the known dimensions before dividing.
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)
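
# Illustrative note (not part of the original source): for a tensor x of shape
# (B, S, 256), split_last(x, (8, -1)) returns a view of shape (B, S, 8, 32).
# This is how the attention layer below slices the embedding dimension into
# per-head chunks; the concrete sizes here are assumed for the example only.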
class MultiHeadAttention(nn.Module):
    # Multi-headed scaled dot-product self-attention.
    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None  # last attention map, kept for inspection/visualization

    def forward(self, x, mask):
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -transpose-> (B, H, S, W)
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        q, k, v = (split_last(t, (self.n_heads, -1)).transpose(1, 2) for t in [q, k, v])
        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S), scaled by sqrt(head width)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            # Push masked (padding) positions to a large negative value before softmax.
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(F.softmax(scores, dim=-1))
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -transpose-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # Merge the head and width dimensions back: (B, S, H, W) -> (B, S, D)
        h = h.view(*h.size()[:-2], -1)
        self.scores = scores
        return h
class PositionWiseFeedForward(nn.Module):
    # Position-wise feed-forward network: two linear layers with a GELU in between.
    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)

    def forward(self, x):
        # (B, S, D) -> (B, S, ff_dim) -> (B, S, D)
        return self.fc2(F.gelu(self.fc1(x)))
class EncoderBlock(nn.Module):
    # Pre-norm Transformer encoder block: self-attention and feed-forward
    # sub-layers, each wrapped with LayerNorm, dropout, and a residual connection.
    def __init__(self, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Attention sub-layer (LayerNorm applied before the sub-layer) + residual.
        h = self.drop(self.proj(self.attn(self.norm1(x), mask)))
        x = x + h
        # Feed-forward sub-layer + residual.
        h = self.drop(self.pwff(self.norm2(x)))
        x = x + h
        return x
class Transformer(nn.Module):
    # Transformer encoder: a stack of identical encoder blocks.
    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.blocks = nn.ModuleList([
            EncoderBlock(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return x
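
# Minimal usage sketch, not part of the original module: the hyperparameters and
# tensor sizes below are illustrative assumptions, not values prescribed by this file.
if __name__ == "__main__":
    import torch

    model = Transformer(num_layers=2, dim=256, num_heads=8, ff_dim=1024, dropout=0.1)
    x = torch.randn(4, 16, 256)   # (batch, seq_len, dim)
    mask = torch.ones(4, 16)      # 1 = attend to this position, 0 = padding
    out = model(x, mask)
    print(out.shape)              # torch.Size([4, 16, 256])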