-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathattention_native.py
188 lines (148 loc) · 9.02 KB
/
attention_native.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from attention import Attention
@torch.no_grad()
def get_seqlen_and_mask(input_resolution, window_size):
    """Count valid sliding-window slots per pixel and build the padding mask.

    A ``window_size x window_size`` window (zero padding ``window_size // 2``,
    stride 1) is slid over an all-ones map of size ``input_resolution``; window
    slots that land in the zero padding are treated as invalid.

    Args:
        input_resolution: ``(H, W)`` spatial size of the feature map.
        window_size: Side length of the square window. With an odd value the
            number of window positions equals ``H * W``.

    Returns:
        attn_local_length: float tensor of shape ``(H * W, 1)``; for each pixel,
            the number of its window slots that lie inside the image.
        attn_mask: bool tensor of shape ``(H * W, window_size ** 2)``; ``True``
            where the window slot falls in the padding.
    """
    height, width = input_resolution[0], input_resolution[1]
    pad = window_size // 2
    attn_map = F.unfold(torch.ones([1, 1, height, width]), window_size,
                        dilation=1, padding=(pad, pad), stride=1)
    # attn_map: (1, window_size ** 2, num_positions). Squeeze only the batch dim
    # so a 1x1 input still yields shape (1, 1) instead of collapsing to 0-d.
    attn_local_length = attn_map.sum(-2).squeeze(0).unsqueeze(-1)
    attn_mask = attn_map.squeeze(0).permute(1, 0) == 0
    return attn_local_length, attn_mask
class AggregatedAttention(nn.Module):
    """Aggregated attention over a local window plus a pooled map, with optional head gating.

    Every query attends, through a single softmax, to
      * its ``window_size x window_size`` neighbourhood, with keys/values
        unfolded from the full-resolution token map, and
      * ``pool_H x pool_W`` pooled tokens (1x1 conv -> GELU -> adaptive average
        pool -> LayerNorm of the whole map).

    The per-head outputs can then be reweighted per token:
      * ``routed_head > 0``: a top-``routed_head`` router over the
        ``num_heads - shared_head`` routable heads; when ``shared_head > 0`` a
        two-way gate (``wg_0``) balances shared vs. routed heads.
      * ``routed_head == 0`` and ``shared_head > 1``: a softmax gate (``wg_1``)
        over all heads.
      * otherwise (including the defaults): heads are combined ungated.

    Args:
        dim: Channel dimension; must be divisible by ``num_heads``.
        input_resolution: ``(H, W)`` used to precompute the padding mask and the
            sequence-length scale; ``forward`` must receive ``H * W`` tokens
            matching this resolution.
        num_heads: Number of attention heads.
        window_size: Odd side length of the local window.
        qkv_bias: Whether the query / key-value projections use a bias.
        attn_drop: Dropout rate on the attention weights.
        proj_drop: Dropout rate on the output projection.
        sr_ratio: Pooling reduction ratio, used when ``fixed_pool_size`` is None.
        fixed_pool_size: If given, pool to a fixed ``size x size`` grid.
        shared_head: Number of always-active heads when gating is used.
        routed_head: Number of heads the router selects per token (top-k).
    """

    def __init__(self, dim, input_resolution, num_heads=8, window_size=3, qkv_bias=True,
                 attn_drop=0., proj_drop=0., sr_ratio=1, fixed_pool_size=None, shared_head=0, routed_head=0):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.sr_ratio = sr_ratio
        assert window_size % 2 == 1, "window size must be odd"
        self.window_size = window_size
        self.local_len = window_size ** 2

        if fixed_pool_size is None:
            self.pool_H, self.pool_W = input_resolution[0] // self.sr_ratio, input_resolution[1] // self.sr_ratio
        else:
            assert fixed_pool_size < min(input_resolution), \
                f"The fixed_pool_size {fixed_pool_size} should be less than the shorter side of input resolution {input_resolution} to ensure pooling works correctly."
            self.pool_H, self.pool_W = fixed_pool_size, fixed_pool_size
        self.pool_len = self.pool_H * self.pool_W

        # NOTE: modules/parameters are created in a fixed order; keep it so that
        # seeded initialisation stays reproducible.
        self.unfold = nn.Unfold(kernel_size=window_size, padding=window_size // 2, stride=1)
        # Stored as inverse-softplus so that softplus(temperature) starts at 1/0.24.
        self.temperature = nn.Parameter(
            torch.log((torch.ones(num_heads, 1, 1) / 0.24).exp() - 1))
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.query_embedding = nn.Parameter(
            nn.init.trunc_normal_(torch.empty(self.num_heads, 1, self.head_dim), mean=0, std=0.02))
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Components to generate pooled features.
        self.pool = nn.AdaptiveAvgPool2d((self.pool_H, self.pool_W))
        self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
        self.norm = nn.LayerNorm(dim)
        self.act = nn.GELU()

        # MLP producing a continuous relative position bias for pooled features.
        self.cpb_fc1 = nn.Linear(2, 512, bias=True)
        self.cpb_act = nn.ReLU(inplace=True)
        self.cpb_fc2 = nn.Linear(512, num_heads, bias=True)

        # Relative bias for local (window) features.
        self.relative_pos_bias_local = nn.Parameter(
            nn.init.trunc_normal_(torch.empty(num_heads, self.local_len), mean=0, std=0.0004))

        # Padding mask and sequence-length scale: log(valid window slots + pool_len) per token.
        local_seq_length, padding_mask = get_seqlen_and_mask(input_resolution, window_size)
        self.register_buffer("seq_length_scale", torch.as_tensor(np.log(local_seq_length.numpy() + self.pool_len)),
                             persistent=False)
        self.register_buffer("padding_mask", padding_mask, persistent=False)

        # Query-dependent dynamic bias for local aggregation.
        self.learnable_tokens = nn.Parameter(
            nn.init.trunc_normal_(torch.empty(num_heads, self.head_dim, self.local_len), mean=0, std=0.02))
        self.learnable_bias = nn.Parameter(torch.zeros(num_heads, 1, self.local_len))

        # Head-gating configuration.
        self.shared_head = shared_head
        self.routed_head = routed_head
        if self.routed_head > 0:
            # topk in forward needs k <= number of routable heads; fail early instead.
            assert self.routed_head <= num_heads - shared_head, (
                f"routed_head {routed_head} cannot exceed the {num_heads - shared_head} routable heads "
                f"(num_heads {num_heads} - shared_head {shared_head}).")
            self.wg = torch.nn.Linear(dim, num_heads - shared_head, bias=False)
            if self.shared_head > 0:
                self.wg_0 = torch.nn.Linear(dim, 2, bias=False)
        if self.shared_head > 1:
            self.wg_1 = torch.nn.Linear(dim, shared_head, bias=False)

    def _routed_head_gates(self, tokens, B, N):
        """Top-k router over the routable heads; returns gates of shape ``(B, N, num_heads - shared_head)``.

        Selected heads keep their softmax weight, renormalised to sum to 1 per
        token and then scaled by ``routed_head``; unselected heads get 0. In
        training mode the auxiliary load-balancing loss
        ``mean(me * ce) * E ** 2`` is appended to ``Attention.LOAD_BALANCING_LOSSES``.
        """
        gates = F.softmax(self.wg(tokens), dim=1)
        num_experts = gates.shape[1]
        _, indices = torch.topk(gates, k=self.routed_head, dim=1)
        # 0/1 mask of the selected heads per token.
        mask = F.one_hot(indices, num_classes=num_experts).sum(dim=1)
        if self.training:
            me = gates.mean(dim=0)
            ce = mask.float().mean(dim=0)
            l_aux = torch.mean(me * ce) * num_experts * num_experts
            Attention.LOAD_BALANCING_LOSSES.append(l_aux)
        routed_head_gates = gates * mask
        denom = torch.sum(routed_head_gates, dim=1, keepdim=True)
        denom = torch.clamp(denom, min=torch.finfo(denom.dtype).eps)
        # Out-of-place division: avoids mutating a tensor that autograd tracks.
        routed_head_gates = routed_head_gates / denom
        return routed_head_gates.reshape(B, N, -1) * self.routed_head

    def _head_gates(self, tokens, B, N, routed_head_gates):
        """Return per-token head gates of shape ``(B, N, num_heads)``, or None when heads are ungated."""
        if self.routed_head > 0:
            if self.shared_head == 0:
                return routed_head_gates
            if self.shared_head > 1:
                shared_head_gates = F.softmax(self.wg_1(tokens), dim=1).reshape(B, N, -1) * self.shared_head
            else:
                # A single shared head gets a constant gate equal to shared_head (= 1).
                shared_head_gates = tokens.new_full((B, N, self.shared_head), float(self.shared_head))
            # Two-way softmax (scaled by 2) trading off shared vs. routed heads.
            weight_0 = F.softmax(self.wg_0(tokens), dim=1).reshape(B, N, 2) * 2
            shared_head_gates = weight_0[:, :, 0:1] * shared_head_gates
            routed_head_gates = weight_0[:, :, 1:2] * routed_head_gates
            return torch.cat([shared_head_gates, routed_head_gates], dim=2)
        if self.shared_head > 1:
            return F.softmax(self.wg_1(tokens), dim=1).reshape(B, N, -1) * self.shared_head
        # No routed heads and at most one shared head: the gate would be the
        # identity (and wg_1 does not exist), so combine heads ungated.
        return None

    def forward(self, x, H, W, relative_pos_index, relative_coords_table):
        """Apply aggregated attention and optional head gating to ``x``.

        Args:
            x: Tokens of shape ``(B, N, C)`` with ``N == H * W``.
            H: Height used to fold tokens back into a spatial map.
            W: Width used to fold tokens back into a spatial map.
            relative_pos_index: Integer index with ``N * pool_len`` entries into
                the rows of ``relative_coords_table``.
            relative_coords_table: Coordinates whose last dimension is 2, fed to
                the continuous-position-bias MLP.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        # Flattened per-token view consumed by the gating networks.
        _x = x.reshape(B * N, C)
        routed_head_gates = self._routed_head_gates(_x, B, N) if self.routed_head > 0 else None

        # Queries: L2-normalise, add query embedding, then magnify by sequence
        # length scale and temperature (softplus keeps the temperature positive).
        q_norm = F.normalize(self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3), dim=-1)
        q_norm_scaled = (q_norm + self.query_embedding) * F.softplus(self.temperature) * self.seq_length_scale

        # Unfolded local keys and values (only the keys are L2-normalised).
        k_local, v_local = self.kv(x).chunk(2, dim=-1)
        k_local = F.normalize(k_local.reshape(B, N, self.num_heads, self.head_dim), dim=-1).reshape(B, N, -1)
        kv_local = torch.cat([k_local, v_local], dim=-1).permute(0, 2, 1).reshape(B, -1, H, W)
        k_local, v_local = self.unfold(kv_local).reshape(
            B, 2 * self.num_heads, self.head_dim, self.local_len, N).permute(0, 1, 4, 2, 3).chunk(2, dim=1)

        # Local similarity; window slots lying in the zero padding are masked to -inf.
        attn_local = ((q_norm_scaled.unsqueeze(-2) @ k_local).squeeze(-2)
                      + self.relative_pos_bias_local.unsqueeze(1)).masked_fill(self.padding_mask, float('-inf'))

        # Pooled features.
        x_ = x.permute(0, 2, 1).reshape(B, -1, H, W).contiguous()
        x_ = self.pool(self.act(self.sr(x_))).reshape(B, -1, self.pool_len).permute(0, 2, 1)
        x_ = self.norm(x_)

        # Pooled keys and values.
        kv_pool = self.kv(x_).reshape(B, self.pool_len, 2 * self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k_pool, v_pool = kv_pool.chunk(2, dim=1)

        # Continuous relative positional bias for pooled features, produced by an MLP.
        pool_bias = self.cpb_fc2(self.cpb_act(self.cpb_fc1(relative_coords_table))).transpose(0, 1)[
            :, relative_pos_index.view(-1)].view(-1, N, self.pool_len)

        # Pooled similarity.
        attn_pool = q_norm_scaled @ F.normalize(k_pool, dim=-1).transpose(-2, -1) + pool_bias

        # One softmax over the concatenated local and pooled similarities.
        attn = torch.cat([attn_local, attn_pool], dim=-1).softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Split the weights and aggregate local and pooled values separately.
        attn_local, attn_pool = torch.split(attn, [self.local_len, self.pool_len], dim=-1)
        x_local = (((q_norm @ self.learnable_tokens) + self.learnable_bias + attn_local).unsqueeze(
            -2) @ v_local.transpose(-2, -1)).squeeze(-2)
        x_pool = attn_pool @ v_pool

        x = (x_local + x_pool).transpose(1, 2)  # B, N, num_heads, head_dim
        head_gates = self._head_gates(_x, B, N, routed_head_gates)
        if head_gates is not None:
            x = head_gates.unsqueeze(-1) * x
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x