Skip to content

Commit

Permalink
first_commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Brilliant-B committed Aug 5, 2023
1 parent bcdd760 commit 02ead2b
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 398 deletions.
15 changes: 8 additions & 7 deletions configs/eva2_hybrid/Segmenter_EVA02_large_24_512_slide_80k.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
drop_rate=0.1,
attn_drop_rate=0.1,
drop_path_rate=0.2,

init_values=None,
use_checkpoint=False,
use_abs_pos_emb=True,
Expand All @@ -43,23 +43,24 @@
pt_hw_seq_len=16,
intp_freq=True,
subln=True,
xattn=True,
xattn=False,
naiveswiglu=True,
pretrained='pretrained/eva02_L_pt_m38m_p14to16.pt',
),
decode_head=dict(
type='SegmenterHead_maskT',
img_size_ori=224, # 重要!
d_encoder=1024, # 重要!
n_layers=2,
n_heads=12,
d_model=768, # 重要!
d_ff=4*768,
d_ff=4 * 768,
drop_path_rate=0.0,
dropout=0.1,
in_channels=224,
channels=512,
in_index=0,

input_transform="resize_concat",
in_channels=(1024, 1024, 1024, 1024), # 重要!
in_index=(0, 1, 2, 3),
channels=1024,
dropout_ratio=0, # no relation
num_classes=60,
norm_cfg=norm_cfg,
Expand Down
43 changes: 13 additions & 30 deletions mmseg/models/backbones/EVA2.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,19 @@
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm, mmseg, setr, xcit and swin code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/fudan-zvg/SETR
# https://github.com/facebookresearch/xcit/
# https://github.com/microsoft/Swin-Transformer
# --------------------------------------------------------'

import torch
from functools import partial
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint

from functools import partial
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from apex.normalization import FusedLayerNorm
from math import pi
from einops import rearrange, repeat
# import xformers.ops as xops

from mmcv_custom import load_checkpoint
from mmseg.utils import get_root_logger
from mmseg.models.builder import BACKBONES

import xformers.ops as xops
from apex.normalization import FusedLayerNorm

from math import pi
from einops import rearrange, repeat


def broadcat(tensors, dim = -1):
Expand Down Expand Up @@ -56,6 +42,7 @@ def rotate_half(x):



'''
class VisionRotaryEmbedding(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -103,7 +90,7 @@ def forward(self, t, start_index = 0):
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
return torch.cat((t_left, t, t_right), dim = -1)

'''



Expand Down Expand Up @@ -140,11 +127,9 @@ def __init__(

freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

self.register_buffer("freqs_cos", freqs_cos)
self.register_buffer("freqs_sin", freqs_sin)


def forward(self, t):
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin

Expand Down Expand Up @@ -312,7 +297,7 @@ def forward(self, x, rel_pos_bias=None):
ro_k_t = self.rope(k_t)
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)

if self.xattn:
if False:
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
Expand Down Expand Up @@ -538,7 +523,7 @@ def __init__(
out_indices=[3, 5, 7, 11],

subln=True,
xattn=True,
xattn=False,
naiveswiglu=True,
rope=True,
pt_hw_seq_len=16,
Expand Down Expand Up @@ -584,6 +569,7 @@ def __init__(
)
else: self.rope = None
self.naiveswiglu = naiveswiglu
self.pretrained = pretrained

dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
Expand Down Expand Up @@ -673,7 +659,7 @@ def _init_weights(m):
def get_num_layers(self):
return len(self.blocks)

# @torch.jit.ignore
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}

Expand Down Expand Up @@ -704,11 +690,8 @@ def forward_features(self, x):
if len(features) == 1:
for i in range(len(ops) - 1):
features.append(features[0])
for i in range(len(features)):
features[i] = ops[i](features[i])
else:
for i in range(len(features)):
features[i] = ops[i](features[i])
for i in range(len(features)):
features[i] = ops[i](features[i])

return tuple(features)

Expand Down
Loading

0 comments on commit 02ead2b

Please sign in to comment.