diff --git a/model.py b/model.py
deleted file mode 100644
index ac6f67c..0000000
--- a/model.py
+++ /dev/null
@@ -1,512 +0,0 @@
-import torch
-import torch.nn as nn
-import torchvision
-import math
-import torch.nn.functional as F
-from torch.cuda.amp import autocast
-from torchvision.models._utils import IntermediateLayerGetter
-import timm
-import numpy as np
-
-from common import scTransformerLayer, scTransformerEncoder, PositionEncodingSine, PSP
-
-# sys.path.append("..")
-# from visual import gem
-
-"""
-ConvNeXt model name mapping (legacy timm name -> current timm name):
-{
-    'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
-    'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
-    'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
-    'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
-    'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
-    'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
-    'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
-    'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
-    'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
-    'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
-    'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
-    'convnext_small_in22k': 'convnext_small.fb_in22k',
-    'convnext_base_in22k': 'convnext_base.fb_in22k',
-    'convnext_large_in22k': 'convnext_large.fb_in22k',
-    'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
-
-    'convnextv2_tiny_22k_224_ema': 'convnextv2_tiny.fcmae_ft_in22k_in1k',
-    'convnextv2_tiny_22k_384_ema': 'convnextv2_tiny.fcmae_ft_in22k_in1k_384',
-}
-"""
-
-
-class Backbone(nn.Module):
-    def __init__(self, model_name, bk_checkpoint, return_interm_layers: bool, img_size=[122, 671]):
-        super().__init__()
-        self.name = model_name
-        if 'resnet' in self.name.lower():
-            backbone = getattr(torchvision.models, self.name.lower())(weights='{}_Weights.IMAGENET1K_V1'.format(self.name))
-            assert self.name in ('ResNet18', 'ResNet34', 'ResNet50'), "number of channels are hard coded"
-
-            if return_interm_layers:
-                # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
-                return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
-                self.strides = [8, 16, 32]
-                if self.name == 'ResNet50':
-                    self.num_channels = [512, 1024, 2048]
-                else:  # ResNet18 / ResNet34
-                    self.num_channels = [128, 256, 512]
-            else:
-                return_layers = {'layer4': "0"}
-                self.strides = [32]
-                if self.name == 'ResNet50':
-                    self.num_channels = [2048]
-                else:  # ResNet18 / ResNet34
-                    self.num_channels = [512]
-            self.backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
-            self.data_config = None
-        elif 'convnext' in self.name.lower():
-            self.backbone = timm.create_model(self.name, pretrained=True, num_classes=0, pretrained_cfg_overlay=dict(file=bk_checkpoint))
-            self.data_config = timm.data.resolve_model_data_config(self.backbone)
-            if return_interm_layers:
-                self.strides = [8, 16, 32]
-                if 'base' in self.name.lower():
-                    self.num_channels = [256, 512, 1024]
-                elif 'tiny' in self.name.lower():
-                    self.num_channels = [192, 384, 768]
-                else:
-                    raise RuntimeError("channel counts are hard coded for ConvNeXt tiny/base only")
-            else:
-                self.strides = [32]
-                self.num_channels = [1024]
-        else:
-            raise RuntimeError(f"unsupported model_name '{model_name}': expected resnet* or convnext*")
-
-    def forward(self, x):
-        if 'resnet' in self.name.lower():
-            xs = self.backbone(x)
-            out = []
-            for _, feat in xs.items():
-                out.append(feat)
-        if 'convnext' in self.name.lower():
-            x = self.backbone.stem(x)
-            x0 = self.backbone.stages[0](x)
-            x1 = self.backbone.stages[1](x0)
-            x2 = self.backbone.stages[2](x1)
-            x3 = self.backbone.stages[3](x2)
-
-            out = [x1, x2, x3]
-        return out
-
-
-class BackboneEmbed(nn.Module):
-    def __init__(self, d_model, backbone_strides, backbone_num_channels, return_interm_layers: bool):
-        super().__init__()
-        self.return_interm_layers = return_interm_layers
-
-        self.d_model = d_model
-        self.pos_embed = PositionEncodingSine(d_model=self.d_model)
-
-        if self.return_interm_layers:
-            # one projection per backbone level, plus an extra stride-2 level on top
-            num_backbone_outs = len(backbone_strides) + 1
-            input_proj_list = []
-            for n in range(num_backbone_outs):
-                if n == num_backbone_outs - 1:
-                    in_channels = backbone_num_channels[n - 1]
-                    input_proj_list.append(nn.Sequential(
-                        nn.Conv2d(in_channels, self.d_model, kernel_size=3, stride=2, padding=1),
-                        nn.GroupNorm(32, self.d_model)))
-                else:
-                    in_channels = backbone_num_channels[n]
-                    input_proj_list.append(nn.Sequential(
-                        nn.Conv2d(in_channels, self.d_model, kernel_size=1),
-                        nn.GroupNorm(32, self.d_model)))
-
-            self.input_proj = nn.ModuleList(input_proj_list)
-        else:
-            self.input_proj = nn.ModuleList([
-                nn.Sequential(
-                    nn.Conv2d(backbone_num_channels[0], self.d_model, kernel_size=1),
-                    nn.GroupNorm(32, self.d_model),
-                )])
-
-    def forward(self, features):
-        feats_embed = []
-        srcs = []
-        for l, feat in enumerate(features):
-            src = self.input_proj[l](feat)
-            srcs.append(src)
-            p = self.pos_embed(src)
-            feats_embed.append(p)
-        if self.return_interm_layers:
-            # extra level: stride-2 projection of the last feature map
-            src = self.input_proj[-1](features[-1])
-            srcs.append(src)
-            p = self.pos_embed(src)
-            feats_embed.append(p)
-        return feats_embed, srcs
-
-
-def weights_init_kaiming(m):
-    classname = m.__class__.__name__
-    if classname.find('Linear') != -1:
-        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
-        nn.init.constant_(m.bias, 0.0)
-    elif classname.find('BatchNorm') != -1:
-        if m.affine:
-            nn.init.constant_(m.weight, 1.0)
-            nn.init.constant_(m.bias, 0.0)
-
-
-class TimmModel(nn.Module):
-    def __init__(self, model_name,
-                 sat_size,
-                 grd_size,
-                 psm=True,
-                 is_polar=True):
-
-        super(TimmModel, self).__init__()
-
-        self.is_polar = is_polar
-        self.backbone_name = model_name
-
-        self.d_model = 128
-        self.nheads = 4
-        self.nlayers = 2
-        self.ffn_dim = 1024
-        self.dropout = 0.3
-        self.em_dim = 4096 // 2
-
-        self.activation = nn.GELU()
-        self.single_features = False
-
-        self.sat_size = sat_size
-        self.grd_size = grd_size
-
-        self.sample = psm
-
-        if 'tiny' in self.backbone_name:
-            if 'v2' in self.backbone_name:
-                self.bk_checkpoint = '../pretrained/convnextv2_tiny_22k_224_ema.pt'
-            else:
-                self.bk_checkpoint = '../pretrained/convnext_tiny_22k_1k_224.pth'
-        elif 'base' in self.backbone_name:
-            if 'v2' in self.backbone_name:
-                self.bk_checkpoint = '../pretrained/convnextv2_base_22k_224_ema.pt'
-            else:
-                self.bk_checkpoint = '../pretrained/convnext_base_22k_1k_224.pth'
-        else:
-            self.bk_checkpoint = None
-
-        if self.bk_checkpoint is not None and '384' in self.backbone_name:
-            self.bk_checkpoint = self.bk_checkpoint.replace('224', '384')
-
-        if self.is_polar:
-            if self.sample:
-                self.norm1 = nn.LayerNorm(self.d_model)
-                self.norm2 = nn.LayerNorm(self.d_model)
-                self.sample_L_sat = PSP(sizes=[(1, 1), (6, 6), (12, 12), (21, 21)], dimension=2)
-                self.sample_L_grd = PSP(sizes=[(1, 1), (6, 6), (12, 12), (21, 21)], dimension=2)
-                self.in_dim_L = 622  # 1*1 + 6*6 + 12*12 + 21*21 pooled tokens
-        else:
-            if self.sample:
-                self.norm1 = nn.LayerNorm(self.d_model)
-                self.norm2 = nn.LayerNorm(self.d_model)
-                self.sample_L_sat = PSP(sizes=[(1, 1), (6, 6), (12, 12), (21, 21)], dimension=2)
-                self.sample_L_grd = PSP(sizes=[(1, 1), (3, 12), (6, 24), (7, 63)], dimension=2)
-                self.in_dim_L = 622  # 1*1 + 3*12 + 6*24 + 7*63 pooled tokens
-
-        # ----------------------- global -----------------------#
-        # Backbone
-        self.backbone = Backbone(self.backbone_name, self.bk_checkpoint, return_interm_layers=not self.single_features)
-
-        # Position and embed
-        self.embed = BackboneEmbed(self.d_model, self.backbone.strides, self.backbone.num_channels, return_interm_layers=not self.single_features)
-
-        # multi-scale self-cross attention for sat
-        layer_sat_H = scTransformerLayer(self.d_model, self.nheads, self.ffn_dim, self.dropout, activation=self.activation, is_ffn=True)
-        self.transformer_sat_H = scTransformerEncoder(layer_sat_H, num_layers=2)
-        layer_sat_L = scTransformerLayer(self.d_model, self.nheads, self.ffn_dim, self.dropout, activation=self.activation, is_ffn=True, q_low=True)
-        self.transformer_sat_L = scTransformerEncoder(layer_sat_L, num_layers=1)
-
-        # multi-scale self-cross attention for grd
-        layer_grd_H = scTransformerLayer(self.d_model, self.nheads, self.ffn_dim, self.dropout, activation=self.activation, is_ffn=True)
-        self.transformer_grd_H = scTransformerEncoder(layer_grd_H, num_layers=2)
-        layer_grd_L = scTransformerLayer(self.d_model, self.nheads, self.ffn_dim, self.dropout, activation=self.activation, is_ffn=True, q_low=True)
-        self.transformer_grd_L = scTransformerEncoder(layer_grd_L, num_layers=1)
-
-        out_dim_g = 14
-
-        self.feat_dim_sat, self.H_sat, self.W_sat = self._dim(self.backbone_name, self.backbone.strides, img_size=self.sat_size)
-        in_dim_sat = sum(self.feat_dim_sat[1:]) + self.in_dim_L if self.sample else sum(self.feat_dim_sat)
-        self.proj_sat = nn.Linear(in_dim_sat, out_dim_g)
-
-        self.feat_dim_grd, self.H_grd, self.W_grd = self._dim(self.backbone_name, self.backbone.strides, img_size=self.grd_size)
-        in_dim_grd = sum(self.feat_dim_grd[1:]) + self.in_dim_L if self.sample else sum(self.feat_dim_grd)
-        self.proj_grd = nn.Linear(in_dim_grd, out_dim_g)
-
-        # ----------------------- local -----------------------#
-        # copy so the backbone's own channel list is not mutated in place
-        self.num_channels = list(self.backbone.num_channels)
-        self.num_channels.append(self.d_model)
-
-        ratio = 1
-        proj_gl_sat = nn.ModuleList(nn.Sequential(
-            nn.Conv1d(self.d_model, self.d_model * ratio, kernel_size=self.k_size(self.d_model), padding=(self.k_size(self.d_model) - 1) // 2),
-            nn.BatchNorm1d(self.d_model * ratio),
-            nn.Conv1d(self.d_model * ratio, self.num_channels[i], kernel_size=self.k_size(self.d_model * ratio), padding=(self.k_size(self.d_model * ratio) - 1) // 2),
-            nn.GELU(),
-            nn.BatchNorm1d(self.num_channels[i])
-        ) for i in range(len(self.num_channels)))
-
-        proj_gl_grd = nn.ModuleList(nn.Sequential(
-            nn.Conv1d(self.d_model, self.d_model * ratio, kernel_size=self.k_size(self.d_model), padding=(self.k_size(self.d_model) - 1) // 2),
-            nn.BatchNorm1d(self.d_model * ratio),
-            nn.Conv1d(self.d_model * ratio, self.num_channels[i], kernel_size=self.k_size(self.d_model * ratio), padding=(self.k_size(self.d_model * ratio) - 1) // 2),
-            nn.GELU(),
-            nn.BatchNorm1d(self.num_channels[i])
-        ) for i in range(len(self.num_channels)))
-
-        proj_gl_sat.apply(weights_init_kaiming)
-        proj_gl_grd.apply(weights_init_kaiming)
-
-        self.proj_gl_sat = proj_gl_sat
-        self.proj_gl_grd = proj_gl_grd
-
-        ch_sat = [nn.Conv2d(self.num_channels[i], self.num_channels[i], kernel_size=1) for i in range(len(self.H_sat))]
-        ch_grd = [nn.Conv2d(self.num_channels[i], self.num_channels[i], kernel_size=1) for i in range(len(self.H_grd))]
-        self.ch_sat = nn.Sequential(*ch_sat)
-        self.ch_grd = nn.Sequential(*ch_grd)
-
-        if not self.is_polar:
-            sat_k = [9, 7, 5, 3]
-            grd_k = [(7, 13), (5, 11), (3, 9), (1, 7)]
-            pad = [(3, 6), (2, 5), (1, 4), (0, 3)]
-            sp_sat = [nn.Conv2d(1, 1, kernel_size=sat_k[i], padding=(sat_k[i] - 1) // 2) for i in range(len(self.num_channels))]
-            sp_grd = [nn.Conv2d(1, 1, kernel_size=(grd_k[i][0], grd_k[i][1]), padding=(pad[i][0], pad[i][1])) for i in range(len(self.num_channels))]
-
-            if self.sample:
-                # the PSP level is a 1-D token sequence, so its spatial kernel is 1-D shaped
-                sp_sat[0] = nn.Conv2d(1, 1, kernel_size=(sat_k[0] * sat_k[0], 1), padding=((sat_k[0] * sat_k[0] - 1) // 2, 0))
-                sp_grd[0] = nn.Conv2d(1, 1, kernel_size=(grd_k[0][0] * grd_k[0][1], 1), padding=((grd_k[0][0] * grd_k[0][1] - 1) // 2, 0))
-        else:
-            sp_k = [(7, 13), (5, 11), (3, 9), (1, 7)]
-            pad = [(3, 6), (2, 5), (1, 4), (0, 3)]
-            sp_sat = [nn.Conv2d(1, 1, kernel_size=(sp_k[i][0], sp_k[i][1]), padding=(pad[i][0], pad[i][1])) for i in range(len(self.num_channels))]
-            sp_grd = [nn.Conv2d(1, 1, kernel_size=(sp_k[i][0], sp_k[i][1]), padding=(pad[i][0], pad[i][1])) for i in range(len(self.num_channels))]
-
-            if self.sample:
-                sp_sat[0] = nn.Conv2d(1, 1, kernel_size=(sp_k[0][0] * sp_k[0][1], 1), padding=((sp_k[0][0] * sp_k[0][1] - 1) // 2, 0))
-                sp_grd[0] = nn.Conv2d(1, 1, kernel_size=(sp_k[0][0] * sp_k[0][1], 1), padding=((sp_k[0][0] * sp_k[0][1] - 1) // 2, 0))
-        self.sp_sat = nn.Sequential(*sp_sat)
-        self.sp_grd = nn.Sequential(*sp_grd)
-
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.sigmoid = nn.Sigmoid()
-
-        out_dim_l = 256
-        self.proj_local_sat = nn.Linear(sum(self.num_channels), out_dim_l)
-        self.proj_local_grd = nn.Linear(sum(self.num_channels), out_dim_l)
-
-        # learnable temperature, CLIP-style initialization
-        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
-
-    def get_config(self):
-        data_config = self.backbone.data_config
-        return data_config
-
-    def set_grad_checkpointing(self, enable=True):
-        # only the timm (ConvNeXt) backbones support gradient checkpointing
-        self.backbone.backbone.set_grad_checkpointing(enable)
-
-    def k_size(self, in_dim):
-        # ECA-style adaptive kernel size: odd, grows with log2 of the channel count
-        # e.g. in_dim=128 -> t=4 -> k_size=5
-        t = int(abs((math.log(in_dim, 2) + 1) / 2))
-        k_size = t if t % 2 else t + 1
-
-        return k_size
-
-    # @autocast()
-    def forward(self, img1, img2=None):
-        if img2 is not None:
-            grd_b = img1.shape[0]
-            sat_b = img2.shape[0]
-            sat_x = self.backbone(img2)
-            grd_x = self.backbone(img1)
-
-            sat_e, sat_src = self.embed(sat_x)
-            grd_e, grd_src = self.embed(grd_x)
-
-            # global
-            sat_embed = [x.flatten(2).transpose(1, 2) for x in sat_e]
-            grd_embed = [x.flatten(2).transpose(1, 2) for x in grd_e]
-
-            # get low / high-level features
-            if self.sample:
-                L_sat_embed = self.sample_L_sat(sat_e[0])
-                L_sat_embed = self.norm1(L_sat_embed.flatten(2).transpose(1, 2))
-                L_grd_embed = self.sample_L_grd(grd_e[0])
-                L_grd_embed = self.norm2(L_grd_embed.flatten(2).transpose(1, 2))
-                sat_x[0] = self.sample_L_sat(sat_x[0])
-                grd_x[0] = self.sample_L_grd(grd_x[0])
-            else:
-                L_sat_embed = sat_embed[0]
-                L_grd_embed = grd_embed[0]
-
-            H_sat_embed = torch.cat(sat_embed[1:], 1)
-            H_grd_embed = torch.cat(grd_embed[1:], 1)
-
-            # grd <-> sat multi-scale cross attention
-            sat_H, sat_L = self.transformer_sat_H(H_sat_embed, L_sat_embed)
-            sat_L, sat_H = self.transformer_sat_L(sat_L, sat_H)
-
-            grd_H, grd_L = self.transformer_grd_H(H_grd_embed, L_grd_embed)
-            grd_L, grd_H = self.transformer_grd_L(grd_L, grd_H)
-
-            sat = torch.cat([sat_L, sat_H], dim=1)  # (B, L_sat, d_model)
-            grd = torch.cat([grd_L, grd_H], dim=1)  # (B, L_grd, d_model)
-
-            sat_global = self.proj_sat(sat.transpose(1, 2)).contiguous().view(sat_b, -1)
-            grd_global = self.proj_grd(grd.transpose(1, 2)).contiguous().view(grd_b, -1)
-
-            # local
-            sat_h1, sat_h2, sat_h3 = self._reshape_feat(sat_H, self.H_sat[1:], self.W_sat[1:])
-            grd_h1, grd_h2, grd_h3 = self._reshape_feat(grd_H, self.H_grd[1:], self.W_grd[1:])
-
-            sat_x.append(sat_src[-1])
-            grd_x.append(grd_src[-1])
-
-            sat_local = self._geo_att(sat_x, [sat_L, sat_h1, sat_h2, sat_h3], proj=self.proj_gl_sat, ch_att=self.ch_sat, sp_att=self.sp_sat, h=self.H_sat, w=self.W_sat)
-            grd_local = self._geo_att(grd_x, [grd_L, grd_h1, grd_h2, grd_h3], proj=self.proj_gl_grd, ch_att=self.ch_grd, sp_att=self.sp_grd, h=self.H_grd, w=self.W_grd)
-
-            sat_local = self.proj_local_sat(sat_local)
-            grd_local = self.proj_local_grd(grd_local)
-
-            desc_sat = torch.cat([sat_global, sat_local], dim=-1)
-            desc_grd = torch.cat([grd_global, grd_local], dim=-1)
-
-            desc_sat = F.normalize(desc_sat.contiguous(), p=2, dim=1)
-            desc_grd = F.normalize(desc_grd.contiguous(), p=2, dim=1)
-
-            return desc_sat.contiguous(), desc_grd.contiguous()
-
-        else:
-            b, _, h, w = img1.shape
-            # single-image inference: square inputs are routed to the satellite branch
-            if h == w:
-                sat_x = self.backbone(img1)
-
-                sat_e, sat_src = self.embed(sat_x)
-                # global
-                sat_embed = [x.flatten(2).transpose(1, 2) for x in sat_e]
-
-                # get low / high-level features
-                if self.sample:
-                    L_sat_embed = self.sample_L_sat(sat_e[0])
-                    L_sat_embed = self.norm1(L_sat_embed.flatten(2).transpose(1, 2))
-                    sat_x[0] = self.sample_L_sat(sat_x[0])
-                else:
-                    L_sat_embed = sat_embed[0]
-                H_sat_embed = torch.cat(sat_embed[1:], 1)
-
-                # grd <-> sat multi-scale cross attention
-                sat_H, sat_L = self.transformer_sat_H(H_sat_embed, L_sat_embed)
-                sat_L, sat_H = self.transformer_sat_L(sat_L, sat_H)
-                sat = torch.cat([sat_L, sat_H], dim=1)  # (B, L_sat, d_model)
-                sat_global = self.proj_sat(sat.transpose(1, 2)).contiguous().view(b, -1)
-
-                # local
-                sat_h1, sat_h2, sat_h3 = self._reshape_feat(sat_H, self.H_sat[1:], self.W_sat[1:])
-                sat_x.append(sat_src[-1])
-                sat_local = self._geo_att(sat_x, [sat_L, sat_h1, sat_h2, sat_h3], proj=self.proj_gl_sat, ch_att=self.ch_sat, sp_att=self.sp_sat, h=self.H_sat, w=self.W_sat)
-                sat_local = self.proj_local_sat(sat_local)
-
-                desc_sat = torch.cat([sat_global, sat_local], dim=-1)
-                desc_sat = F.normalize(desc_sat.contiguous(), p=2, dim=1)
-
-                return desc_sat
-
-            else:
-                grd_x = self.backbone(img1)
-                grd_e, grd_src = self.embed(grd_x)
-
-                # global
-                grd_embed = [x.flatten(2).transpose(1, 2) for x in grd_e]
-
-                # get low / high-level features
-                if self.sample:
-                    L_grd_embed = self.sample_L_grd(grd_e[0])
-                    L_grd_embed = self.norm2(L_grd_embed.flatten(2).transpose(1, 2))
-                    grd_x[0] = self.sample_L_grd(grd_x[0])
-                else:
-                    L_grd_embed = grd_embed[0]
-
-                H_grd_embed = torch.cat(grd_embed[1:], 1)
-
-                # grd <-> sat multi-scale cross attention
-                grd_H, grd_L = self.transformer_grd_H(H_grd_embed, L_grd_embed)
-                grd_L, grd_H = self.transformer_grd_L(grd_L, grd_H)
-                grd = torch.cat([grd_L, grd_H], dim=1)  # (B, L_grd, d_model)
-                grd_global = self.proj_grd(grd.transpose(1, 2)).contiguous().view(b, -1)
-
-                # local
-                grd_h1, grd_h2, grd_h3 = self._reshape_feat(grd_H, self.H_grd[1:], self.W_grd[1:])
-                grd_x.append(grd_src[-1])
-                grd_local = self._geo_att(grd_x, [grd_L, grd_h1, grd_h2, grd_h3], proj=self.proj_gl_grd, ch_att=self.ch_grd, sp_att=self.sp_grd, h=self.H_grd, w=self.W_grd)
-                grd_local = self.proj_local_grd(grd_local)
-
-                desc_grd = torch.cat([grd_global, grd_local], dim=-1)
-                desc_grd = F.normalize(desc_grd.contiguous(), p=2, dim=1)
-
-                return desc_grd
-
-    def _reshape_feat(self, feat_H, H, W):
-        # split the concatenated high-level token sequence back into its three levels
-        p1 = H[0] * W[0]
-        p2 = H[-1] * W[-1]
-        feat_h1 = feat_H[:, :p1, :].contiguous()
-        feat_h2 = feat_H[:, p1:-p2, :].contiguous()
-        feat_h3 = feat_H[:, -p2:, :].contiguous()
-
-        return [feat_h1, feat_h2, feat_h3]
-
-    def _geo_att(self, local_feats, global_feats, proj, ch_att, sp_att, h, w):
-        geo_att = []
-        for i, feat in enumerate(local_feats):
-            global_feat = proj[i](global_feats[i].transpose(1, 2))
-            b, c, _ = global_feat.shape
-            if self.sample and i == 0:
-                feat = feat.unsqueeze(-1)
-                global_feat = global_feat.unsqueeze(-1)
-            else:
-                global_feat = global_feat.reshape(b, c, h[i], w[i])
-
-            # channel attention
-            avg_out = self.avg_pool(global_feat)
-            att_ch = ch_att[i](avg_out)
-
-            # spatial attention
-            max_out, _ = torch.max(global_feat, dim=1, keepdim=True)
-            att_sp = sp_att[i](max_out)
-            m = feat * self.sigmoid(att_ch) * self.sigmoid(att_sp)
-            m = feat + m
-            m = self.avg_pool(m)
-
-            geo_att.append(m.view(b, -1))
-
-        results = torch.cat(geo_att, dim=-1).contiguous()
-        return results
-
-    def _dim(self, model_name, strides, img_size=[122, 671]):
-        if 'convnext' in model_name.lower():
-            H = [math.floor(img_size[0] / r) for r in strides]
-            W = [math.floor(img_size[1] / r) for r in strides]
-            feat_dim = [H[i] * W[i] for i in range(len(H))]
-        elif 'resnet' in model_name.lower():
-            H = [math.ceil(img_size[0] / r) for r in strides]
-            W = [math.ceil(img_size[1] / r) for r in strides]
-            feat_dim = [H[i] * W[i] for i in range(len(H))]
-        # extra level produced by the stride-2 projection in BackboneEmbed
-        H.append(math.ceil(H[-1] / 2))
-        W.append(math.ceil(W[-1] / 2))
-        feat_dim.append(H[-1] * W[-1])
-        return feat_dim, H, W
\ No newline at end of file
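
For reviewers, a minimal usage sketch of the removed `TimmModel`. This is hypothetical: the model name, input sizes, and the retrieval step are illustrative, inferred from the defaults visible in the deleted code, and it assumes the hard-coded checkpoint path (`../pretrained/convnext_base_22k_1k_224.pth`) exists, since the constructor loads it through timm.

```python
import torch

from model import TimmModel  # the module removed by this diff

# Illustrative config: 'base' in the name selects the hard-coded ConvNeXt
# checkpoint, and [122, 671] matches the default img_size for polar inputs.
model = TimmModel('convnext_base.fb_in22k_ft_in1k',
                  sat_size=[122, 671], grd_size=[122, 671],
                  psm=True, is_polar=True).eval()

grd = torch.randn(2, 3, 122, 671)  # ground panoramas
sat = torch.randn(2, 3, 122, 671)  # polar-transformed satellite images

with torch.no_grad():
    # paired mode: forward(img1=grd, img2=sat) returns L2-normalized descriptors
    desc_sat, desc_grd = model(grd, sat)

# With these defaults each descriptor is (B, 2048):
# global 14 * d_model(128) = 1792, plus local 256.
sim = desc_grd @ desc_sat.t()  # cosine similarities for cross-view retrieval
```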