
Commit

Add files via upload
ytao-wang authored Mar 6, 2024
1 parent 5d38a1e commit bb45fc5
Showing 2 changed files with 724 additions and 0 deletions.
common.py: 212 additions, 0 deletions
@@ -0,0 +1,212 @@
import math
import torch
import copy
from torch import nn

import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_


class scTransformerLayer(nn.Module):
def __init__(self, d_model, nheads, dim_feedforward, dropout, is_ffn=True, qk_cat=True, q_low=False, activation=nn.ReLU(inplace=True), mode='linear'):
super().__init__()
self.is_ffn = is_ffn
self.qk_cat = qk_cat
self.q_low = q_low
self.dim = d_model // nheads
self.nheads = nheads
self.query_proj = nn.Linear(d_model, d_model)
self.key_proj = nn.Linear(d_model, d_model)
self.value_proj = nn.Linear(d_model, d_model)
self.merge = nn.Linear(d_model, d_model, bias=False)
self.mode = mode

assert self.mode in ('linear', 'full', 'conv1'), "scTransformerLayer mode error (set linear, full or conv1)"

# self-cross attention
self.sc_attn = LinearAttention()
self.norm1 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)

self.activation = activation


        # feed-forward network
        if self.is_ffn:
            if mode in ('linear', 'full'):
self.mlp = nn.Sequential(
nn.Linear(d_model, dim_feedforward, bias=False),
self.activation,
nn.Dropout(dropout),
nn.Linear(dim_feedforward, d_model, bias=False)
)
            elif mode == 'conv1':
                # adaptive odd kernel sizes derived from the channel counts
                # (ECA-Net-style: k ~ log2(C))
                t1 = int(abs((math.log(d_model, 2) + 1) / 2))
                k1_size = t1 if t1 % 2 else t1 + 1

                t2 = int(abs((math.log(dim_feedforward, 2) + 1) / 2))
                k2_size = t2 if t2 % 2 else t2 + 1
self.mlp = nn.Sequential(
nn.Conv1d(d_model, dim_feedforward, kernel_size=k1_size, padding=(k1_size - 1) // 2, bias=False),
self.activation,
nn.Dropout(dropout),
nn.Conv1d(dim_feedforward, d_model, kernel_size=k2_size, padding=(k2_size - 1) // 2, bias=False),
)

self.norm2 = nn.LayerNorm(d_model)
self.dropout2 = nn.Dropout(dropout)

self._reset_parameters()


def _reset_parameters(self):
xavier_uniform_(self.query_proj.weight.data)
constant_(self.query_proj.bias.data, 0.)

xavier_uniform_(self.key_proj.weight.data)
constant_(self.key_proj.bias.data, 0.)

xavier_uniform_(self.value_proj.weight.data)
constant_(self.value_proj.bias.data, 0.)


def forward(self, q_src, kv_src, q_mask=None, kv_mask=None):
"""
Args:
q_src (torch.Tensor): [N, L, C]
            kv_src (torch.Tensor): [N, S, C]
q_mask (torch.Tensor): [N, L] (optional)
kv_mask (torch.Tensor): [N, S] (optional)
"""
b = q_src.shape[0]
query = self.query_proj(q_src).view(b, -1, self.nheads, self.dim) # [N, L, (H, D)]

        if self.qk_cat:
            # self-cross attention: keys/values span both sources, so one pass
            # attends within q_src and across to kv_src; q_low sets the order
            if not self.q_low:
                key = value = torch.cat([kv_src, q_src], dim=1)
            else:
                key = value = torch.cat([q_src, kv_src], dim=1)
        else:
            key = value = kv_src
key = self.key_proj(key).view(b, -1, self.nheads, self.dim)
value = self.value_proj(value).view(b, -1, self.nheads, self.dim)

att = self.sc_attn(query, key, value, q_mask, kv_mask) # [N, L, (H, D)]
att = self.merge(att.view(b, -1, self.nheads * self.dim)) # [N, L, C]

att = self.norm1(q_src + self.dropout1(att))

# ffn
if self.is_ffn:
if self.mode == 'conv1':
att = att.permute(0, 2, 1)
att = att + self.dropout2(self.mlp(att))
if self.mode == 'conv1':
att = att.permute(0, 2, 1)
att = self.norm2(att)

return att, kv_src
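
# A minimal usage sketch for scTransformerLayer (illustrative shapes and
# hyper-parameters, not values fixed by this file): with qk_cat=True, keys and
# values span both inputs, so a single layer mixes self- and cross-attention.
#
#   layer = scTransformerLayer(d_model=256, nheads=8, dim_feedforward=512, dropout=0.1)
#   q_src = torch.rand(2, 100, 256)   # [N, L, C]
#   kv_src = torch.rand(2, 400, 256)  # [N, S, C]
#   out, _ = layer(q_src, kv_src)     # out: [2, 100, 256]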


class PositionEncodingSine(nn.Module):
"""
    This is a sinusoidal position encoding generalized to 2-dimensional images.
"""

def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
"""
Args:
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
                on the final performance. For now, we keep both impls for backward compatibility.
We will remove the buggy impl after re-training all variants of our released models.
"""
super().__init__()

pe = torch.zeros((d_model, *max_shape))
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
if temp_bug_fix:
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
        else:  # a buggy implementation (for backward compatibility only)
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
div_term = div_term[:, None, None] # [C//4, 1, 1]
pe[0::4, :, :] = torch.sin(x_position * div_term)
pe[1::4, :, :] = torch.cos(x_position * div_term)
pe[2::4, :, :] = torch.sin(y_position * div_term)
pe[3::4, :, :] = torch.cos(y_position * div_term)

self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]

def forward(self, x):
"""
Args:
x: [N, C, H, W]
"""
return x + self.pe[:, :, :x.size(2), :x.size(3)]
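
# Usage sketch (illustrative shapes): add the encoding to a feature map before
# flattening it for attention; any H, W up to max_shape is supported.
#
#   pos_enc = PositionEncodingSine(d_model=256)
#   feat = torch.rand(2, 256, 60, 80)  # [N, C, H, W]
#   feat = pos_enc(feat)               # same shape, now position-aware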

def _get_clones(module, N):
    """Return a ModuleList of N deep copies of `module`."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
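
# e.g. _get_clones(scTransformerLayer(256, 8, 512, 0.1), N=4) yields four
# independently parameterized copies for a 4-layer stack (illustrative values).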


def elu_feature_map(x):
    # kernel feature map phi(x) = elu(x) + 1: strictly positive, so the
    # linear-attention normalizer Z below can never vanish
    return torch.nn.functional.elu(x) + 1

# from LoFTR
class LinearAttention(nn.Module):
def __init__(self, eps=1e-6):
super().__init__()
self.feature_map = elu_feature_map
self.eps = eps

def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
""" Multi-Head linear attention proposed in "Transformers are RNNs"
Args:
queries: [N, L, H, D]
keys: [N, S, H, D]
values: [N, S, H, D]
q_mask: [N, L]
kv_mask: [N, S]
Returns:
            queried_values: [N, L, H, D]
"""
Q = self.feature_map(queries)
K = self.feature_map(keys)

# set padded position to zero
if q_mask is not None:
Q = Q * q_mask[:, :, None, None]
if kv_mask is not None:
K = K * kv_mask[:, :, None, None]
values = values * kv_mask[:, :, None, None]

        v_length = values.size(1)
        values = values / v_length  # pre-scale to prevent fp16 overflow
        KV = torch.einsum("nshd,nshv->nhdv", K, values)  # K^T V, aggregated over S
        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)  # per-query normalizer
        queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length  # undo the pre-scaling

return queried_values.contiguous()
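
# LinearAttention reassociates (phi(Q) @ phi(K)^T) @ V as phi(Q) @ (phi(K)^T @ V),
# so cost grows linearly in L and S rather than as L*S. A usage sketch
# (illustrative shapes, head dim D = d_model // nheads):
#
#   attn = LinearAttention()
#   q = torch.rand(2, 100, 8, 32)  # [N, L, H, D]
#   k = torch.rand(2, 400, 8, 32)  # [N, S, H, D]
#   v = torch.rand(2, 400, 8, 32)  # [N, S, H, D]
#   out = attn(q, k, v)            # [2, 100, 8, 32]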

class PSP(nn.Module):
    # pyramid pooling (the original PSPNet uses sizes 1, 2, 3, 6)
    def __init__(self, sizes=((1, 1), (3, 3), (6, 6), (8, 8)), dimension=2):
        super().__init__()
        self.stages = nn.ModuleList([self._make_stage(size, dimension) for size in sizes])

    def _make_stage(self, size, dimension=2):
        if dimension == 1:
            prior = nn.AdaptiveAvgPool1d(output_size=size)
        elif dimension == 2:
            prior = nn.AdaptiveAvgPool2d(output_size=(size[0], size[1]))
        elif dimension == 3:
            prior = nn.AdaptiveAvgPool3d(output_size=(size[0], size[1], size[2]))
        else:
            raise ValueError(f"PSP dimension must be 1, 2 or 3, got {dimension}")
        return prior

def forward(self, feats):
n, c, _, _ = feats.size()
priors = [stage(feats).view(n, c, -1) for stage in self.stages]
center = torch.cat(priors, -1)
return center
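
# Usage sketch (illustrative shapes): with the default sizes the flattened
# pyramid holds 1*1 + 3*3 + 6*6 + 8*8 = 110 pooled descriptors per channel.
#
#   psp = PSP()
#   feats = torch.rand(2, 256, 60, 80)  # [N, C, H, W]
#   center = psp(feats)                 # [2, 256, 110]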