# attention_block.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MAB(nn.Module):
"""
Multi-Headed Attention Block (MAB).
    Implements the multi-head attention mechanism used as the core building block of the
    Set Transformer: queries attend over a set of keys/values, followed by a residual
    feed-forward transformation.
Attributes:
        dim_V (int): Dimension of the value (and output) vectors.
num_heads (int): Number of attention heads.
fc_q (nn.Linear): Fully connected layer for transforming the query input.
fc_k (nn.Linear): Fully connected layer for transforming the key input.
fc_v (nn.Linear): Fully connected layer for transforming the value input.
ln0 (nn.LayerNorm, optional): Layer normalization applied before the output linear layer.
ln1 (nn.LayerNorm, optional): Layer normalization applied after the output linear layer.
fc_o (nn.Linear): Output fully connected layer.
"""

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
"""
Initialize the MAB.
Args:
dim_Q (int): Dimension of the query input.
dim_K (int): Dimension of the key input.
dim_V (int): Dimension of the value input.
            num_heads (int): Number of attention heads; must evenly divide dim_V.
ln (bool, optional): Whether to use layer normalization. Default is False.
"""
super(MAB, self).__init__()
self.dim_V = dim_V
self.num_heads = num_heads
self.fc_q = nn.Linear(dim_Q, dim_V)
self.fc_k = nn.Linear(dim_K, dim_V)
self.fc_v = nn.Linear(dim_K, dim_V)
if ln:
self.ln0 = nn.LayerNorm(dim_V)
self.ln1 = nn.LayerNorm(dim_V)
self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
"""
Forward pass for MAB.
Args:
            Q (torch.Tensor): Query tensor of shape (batch_size, seq_len_q, dim_Q).
            K (torch.Tensor): Key tensor of shape (batch_size, seq_len_k, dim_K); the values are also derived from K.
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len_q, dim_V).
"""
# Transform inputs using fully connected layers
Q = self.fc_q(Q)
K, V = self.fc_k(K), self.fc_v(K)
# Split tensors for multi-head attention
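        # Each (batch, seq, dim_V) projection is split into num_heads chunks along the
        # feature dimension and stacked along the batch dimension, giving tensors of shape
        # (batch * num_heads, seq, dim_V // num_heads) so that bmm runs once per head.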
dim_split = self.dim_V // self.num_heads
Q_ = torch.cat(Q.split(dim_split, 2), 0)
K_ = torch.cat(K.split(dim_split, 2), 0)
V_ = torch.cat(V.split(dim_split, 2), 0)
# Compute attention scores
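        # Scores are scaled by sqrt(dim_V) and normalised with a softmax over the key axis.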
A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
# Compute weighted sum of values
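        # A residual connection from Q_ is added, then the per-head chunks (stacked along the
        # batch dimension) are concatenated back along the feature dimension, restoring the
        # shape (batch, seq_q, dim_V).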
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
# Apply layer normalization if enabled
O = O if getattr(self, "ln0", None) is None else self.ln0(O)
# Add a residual connection with a linear transformation and ReLU activation
O = O + F.relu(self.fc_o(O))
# Apply another layer normalization if enabled
O = O if getattr(self, "ln1", None) is None else self.ln1(O)
return O
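
# Example usage for MAB (an illustrative sketch; the shapes assume dim_Q = dim_K = dim_V = 64
# and num_heads = 4, which must evenly divide dim_V):
#     mab = MAB(dim_Q=64, dim_K=64, dim_V=64, num_heads=4, ln=True)
#     out = mab(torch.rand(2, 16, 64), torch.rand(2, 32, 64))
#     out.shape  # torch.Size([2, 16, 64]): one output per query element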


class ISAB(nn.Module):
"""
Induced Set Attention Block (ISAB).
    Instead of computing attention between every pair of input elements, attention is routed
    through a small set of learnable inducing points, which reduces the quadratic cost in the
    set size to a linear one.
Attributes:
I (nn.Parameter): A set of learnable inducing points.
mab0 (MAB): First Multi-Headed Attention Block.
mab1 (MAB): Second Multi-Headed Attention Block.
"""

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
"""
Initialize the ISAB.
Args:
dim_in (int): Dimension of the input.
dim_out (int): Desired dimension of the output.
num_heads (int): Number of attention heads for the MABs.
num_inds (int): Number of inducing points.
ln (bool, optional): Whether to use layer normalization in MABs. Default is False.
"""
super(ISAB, self).__init__()
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
nn.init.xavier_uniform_(self.I)
# Define the two Multi-Headed Attention Blocks
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
"""
Forward pass for ISAB.
Args:
X (torch.Tensor): Input tensor of shape (batch_size, seq_length, dim_in).
Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_length, dim_out).
"""
# Compute the induced attention from the learnable inducing points to the input X
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
# Compute the attention from the input X to the induced attention H
return self.mab1(X, H)
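
# Example usage for ISAB (an illustrative sketch): a set of n = 100 elements of dimension 64
# is mapped to 100 elements of dimension 128 by attending through 32 inducing points, so the
# attention cost scales with n * num_inds rather than n ** 2:
#     isab = ISAB(dim_in=64, dim_out=128, num_heads=4, num_inds=32, ln=True)
#     out = isab(torch.rand(2, 100, 64))
#     out.shape  # torch.Size([2, 100, 128])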


class PMA(nn.Module):
"""
Pooling by Multi-Headed Attention (PMA).
    Implements a pooling mechanism in which a set of learnable seed vectors attends over the
    input set, reducing it to a fixed number (num_seeds) of output vectors.
Attributes:
S (nn.Parameter): A set of learnable seed vectors.
mab (MAB): Multi-Headed Attention Block used for the pooling.
"""

    def __init__(self, dim, num_heads, num_seeds, ln=False):
"""
Initialize the PMA.
Args:
dim (int): Dimension of the input and the seed vectors.
num_heads (int): Number of attention heads for the MAB.
num_seeds (int): Number of seed vectors.
ln (bool, optional): Whether to use layer normalization in the MAB. Default is False.
"""
super(PMA, self).__init__()
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
nn.init.xavier_uniform_(self.S)
# Define the Multi-Headed Attention Block for pooling
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
"""
Forward pass for PMA.
Args:
X (torch.Tensor): Input tensor of shape (batch_size, seq_length, dim).
Returns:
            torch.Tensor: Pooled output tensor of shape (batch_size, num_seeds, dim).
"""
        # The learnable seed vectors act as queries that attend over the input set X,
        # pooling it down to num_seeds output vectors
return self.mab(self.S.repeat(X.size(0), 1, 1), X)
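

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch): stack two ISABs as a set encoder and pool
    # the result with a single-seed PMA, mirroring the encoder/pooling layout these blocks
    # are designed for. All sizes below are arbitrary example values.
    x = torch.rand(4, 50, 3)  # batch of 4 sets, each with 50 elements of dimension 3
    encoder = nn.Sequential(
        ISAB(dim_in=3, dim_out=128, num_heads=4, num_inds=16, ln=True),
        ISAB(dim_in=128, dim_out=128, num_heads=4, num_inds=16, ln=True),
    )
    pool = PMA(dim=128, num_heads=4, num_seeds=1, ln=True)
    z = encoder(x)      # (4, 50, 128): per-element features
    pooled = pool(z)    # (4, 1, 128): one pooled vector per set
    assert z.shape == (4, 50, 128)
    assert pooled.shape == (4, 1, 128)
    print("MAB/ISAB/PMA smoke test passed:", tuple(pooled.shape))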