diff --git a/configs/eva2_hybrid/Segmenter_EVA02_large_24_512_slide_80k.py b/configs/eva2_hybrid/Segmenter_EVA02_large_24_512_slide_80k.py
index 66599c0..7283371 100644
--- a/configs/eva2_hybrid/Segmenter_EVA02_large_24_512_slide_80k.py
+++ b/configs/eva2_hybrid/Segmenter_EVA02_large_24_512_slide_80k.py
@@ -33,7 +33,7 @@
         drop_rate=0.1,
         attn_drop_rate=0.1,
         drop_path_rate=0.2,
-
+        init_values=None,
         use_checkpoint=False,
         use_abs_pos_emb=True,
@@ -43,23 +43,24 @@
         pt_hw_seq_len=16,
         intp_freq=True,
         subln=True,
-        xattn=True,
+        xattn=False,
         naiveswiglu=True,
         pretrained='pretrained/eva02_L_pt_m38m_p14to16.pt',
     ),
     decode_head=dict(
         type='SegmenterHead_maskT',
         img_size_ori=224,  # important!
-        d_encoder=1024,  # important!
         n_layers=2,
         n_heads=12,
         d_model=768,  # important!
-        d_ff=4*768,
+        d_ff=4 * 768,
         drop_path_rate=0.0,
         dropout=0.1,
-        in_channels=224,
-        channels=512,
-        in_index=0,
+
+        input_transform="resize_concat",
+        in_channels=(1024, 1024, 1024, 1024),  # important!
+        in_index=(0, 1, 2, 3),
+        channels=1024,
         dropout_ratio=0,  # no relation
         num_classes=60,
         norm_cfg=norm_cfg,
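
An illustrative aside on the decode_head change above (not part of the patch): with mmseg's "resize_concat" input transform, the features selected by in_index are resized to a common spatial size and concatenated along the channel axis, so the head now receives 4 x 1024 = 4096 input channels, which is what its linear projection consumes. A minimal sketch of that aggregation, with made-up shapes; this is not the mmseg implementation itself:

    # Rough sketch of a "resize_concat"-style input transform (illustrative only;
    # function name and shapes here are assumptions, not code from this repository).
    import torch
    import torch.nn.functional as F

    def resize_concat(feats, in_index=(0, 1, 2, 3)):
        feats = [feats[i] for i in in_index]
        size = feats[0].shape[2:]                      # common (H, W)
        feats = [F.interpolate(f, size=size, mode='bilinear', align_corners=False)
                 for f in feats]
        return torch.cat(feats, dim=1)                 # (B, 4 * 1024, H, W)

    x = resize_concat([torch.randn(2, 1024, 32, 32) for _ in range(4)])
    proj_dec = torch.nn.Linear(x.shape[1], 768)        # 4096 -> d_model = 768
    tokens = proj_dec(x.flatten(2).transpose(1, 2))    # (B, H*W, 768)
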
diff --git a/mmseg/models/backbones/EVA2.py b/mmseg/models/backbones/EVA2.py
index 71669ae..deb82bc 100644
--- a/mmseg/models/backbones/EVA2.py
+++ b/mmseg/models/backbones/EVA2.py
@@ -1,33 +1,19 @@
-# --------------------------------------------------------
-# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
-# Github source: https://github.com/microsoft/unilm/tree/master/beit
-# Copyright (c) 2021 Microsoft
-# Licensed under The MIT License [see LICENSE for details]
-# By Hangbo Bao
-# Based on timm, mmseg, setr, xcit and swin code bases
-# https://github.com/rwightman/pytorch-image-models/tree/master/timm
-# https://github.com/fudan-zvg/SETR
-# https://github.com/facebookresearch/xcit/
-# https://github.com/microsoft/Swin-Transformer
-# --------------------------------------------------------'
-
 import torch
-from functools import partial
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
+from functools import partial
 from timm.models.layers import drop_path, to_2tuple, trunc_normal_
+from apex.normalization import FusedLayerNorm
+from math import pi
+from einops import rearrange, repeat
+# import xformers.ops as xops
 
 from mmcv_custom import load_checkpoint
 from mmseg.utils import get_root_logger
 from mmseg.models.builder import BACKBONES
 
-import xformers.ops as xops
-from apex.normalization import FusedLayerNorm
-
-from math import pi
-from einops import rearrange, repeat
 
 
 def broadcat(tensors, dim = -1):
@@ -56,6 +42,7 @@ def rotate_half(x):
 
 
 
+'''
 class VisionRotaryEmbedding(nn.Module):
     def __init__(
         self,
@@ -103,7 +90,7 @@ def forward(self, t, start_index = 0):
         t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
         t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
         return torch.cat((t_left, t, t_right), dim = -1)
-
+'''
 
 
@@ -140,11 +127,9 @@ def __init__(
         freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
         freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
 
-
         self.register_buffer("freqs_cos", freqs_cos)
         self.register_buffer("freqs_sin", freqs_sin)
 
-
     def forward(self, t): return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
 
@@ -312,7 +297,7 @@ def forward(self, x, rel_pos_bias=None):
             ro_k_t = self.rope(k_t)
             k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
 
-        if self.xattn:
+        if False:
             q = q.permute(0, 2, 1, 3)   # B, num_heads, N, C -> B, N, num_heads, C
             k = k.permute(0, 2, 1, 3)
             v = v.permute(0, 2, 1, 3)
@@ -538,7 +523,7 @@ def __init__(
                  out_indices=[3, 5, 7, 11],
                  subln=True,
-                 xattn=True,
+                 xattn=False,
                  naiveswiglu=True,
                  rope=True,
                  pt_hw_seq_len=16,
@@ -584,6 +569,7 @@ def __init__(
             )
         else: self.rope = None
 
         self.naiveswiglu = naiveswiglu
+        self.pretrained = pretrained
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.use_rel_pos_bias = use_rel_pos_bias
@@ -673,7 +659,7 @@ def _init_weights(m):
     def get_num_layers(self):
         return len(self.blocks)
 
-    # @torch.jit.ignore
+    @torch.jit.ignore
    def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}
@@ -704,11 +690,8 @@ def forward_features(self, x):
         if len(features) == 1:
             for i in range(len(ops) - 1):
                 features.append(features[0])
-            for i in range(len(features)):
-                features[i] = ops[i](features[i])
-        else:
-            for i in range(len(features)):
-                features[i] = ops[i](features[i])
+        for i in range(len(features)):
+            features[i] = ops[i](features[i])
 
         return tuple(features)
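
A side note on the xattn-related edits above (config xattn=False, the commented-out xformers import, and the hard-coded `if False:`): together they route every Attention.forward call onto the plain PyTorch branch rather than xformers' memory-efficient kernel. The sketch below shows what such a fallback branch typically computes; it is a generic illustration under that assumption, not the exact code of this backbone:

    # Generic scaled-dot-product attention, as a stand-in for the non-xattn branch.
    import torch

    def plain_attention(q, k, v, scale, rel_pos_bias=None):
        # q, k, v: (B, num_heads, N, head_dim)
        attn = (q * scale) @ k.transpose(-2, -1)       # (B, num_heads, N, N)
        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias                 # optional relative position bias
        attn = attn.softmax(dim=-1)
        out = attn @ v                                 # (B, num_heads, N, head_dim)
        return out.transpose(1, 2).flatten(2)          # (B, N, num_heads * head_dim)
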
diff --git a/mmseg/models/backbones/Mydeit.py b/mmseg/models/backbones/Mydeit.py
deleted file mode 100644
index 99eeb3d..0000000
--- a/mmseg/models/backbones/Mydeit.py
+++ /dev/null
@@ -1,314 +0,0 @@
-import math
-import warnings
-
-import torch
-import torch.nn as nn
-from mmcv.cnn import (build_norm_layer, constant_init, kaiming_init,
-                      normal_init, trunc_normal_init)
-from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
-from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
-from torch.nn.modules.batchnorm import _BatchNorm
-from torch.nn.modules.utils import _pair as to_2tuple
-
-from mmseg.ops import resize
-from mmseg.utils import get_root_logger
-from ..builder import BACKBONES
-from ..utils import PatchEmbed, vit_convert
-
-from mmseg.models.backbones.vit import TransformerEncoderLayer, VisionTransformer
-
-
-@BACKBONES.register_module()
-class Deit(VisionTransformer):
-    """Vision Transformer variant. (Deit)
-
-    A PyTorch implement of : `DeiT: Data-efficient Image Transformers`
-        https://arxiv.org/abs/2012.12877
-
-    Compared with ViT:
-    a) add self.dist_token attribute beside cls_token
-    b) init method of self.pos_embed: torch.zeros->torch.randn
-
-    Args:
-        img_size (int | tuple): Input image size. Default: 224.
-        patch_size (int): The patch size. Default: 16.
-        in_channels (int): Number of input channels. Default: 3.
-        embed_dims (int): embedding dimension. Default: 768.
-        num_layers (int): depth of transformer. Default: 12.
-        num_heads (int): number of attention heads. Default: 12.
-        mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
-            Default: 4.
-        out_indices (list | tuple | int): Output from which stages.
-            Default: -1.
-        qkv_bias (bool): enable bias for qkv if True. Default: True.
-        drop_rate (float): Probability of an element to be zeroed.
-            Default 0.0
-        attn_drop_rate (float): The drop out rate for attention layer.
-            Default 0.0
-        drop_path_rate (float): stochastic depth rate. Default 0.0
-        with_cls_token (bool): Whether concatenating class token into image
-            tokens as transformer input. Default: True.
-        output_cls_token (bool): Whether output the cls_token. If set True,
-            `with_cls_token` must be True. Default: False.
-        norm_cfg (dict): Config dict for normalization layer.
-            Default: dict(type='LN')
-        act_cfg (dict): The activation config for FFNs.
-            Defalut: dict(type='GELU').
-        patch_norm (bool): Whether to add a norm in PatchEmbed Block.
-            Default: False.
-        final_norm (bool): Whether to add a additional layer to normalize
-            final feature map. Default: False.
-        interpolate_mode (str): Select the interpolate mode for position
-            embeding vector resize. Default: bicubic.
-        num_fcs (int): The number of fully-connected layers for FFNs.
-            Default: 2.
-        norm_eval (bool): Whether to set norm layers to eval mode, namely,
-            freeze running stats (mean and var). Note: Effect on Batch Norm
-            and its variants only. Default: False.
-        with_cp (bool): Use checkpoint or not. Using checkpoint will save
-            some memory while slowing down the training speed. Default: False.
-        pretrain_style (str): Choose to use timm or mmcls pretrain weights.
-            Default: timm.
-        pretrained (str, optional): model pretrained path. Default: None.
-        init_cfg (dict or list[dict], optional): Initialization config dict.
-            Default: None.
-    """
-
-    def __init__(self,
-                 img_size=224,
-                 patch_size=16,
-                 in_channels=3,
-                 embed_dims=768,
-                 num_layers=12,
-                 num_heads=12,
-                 mlp_ratio=4,
-                 out_indices=-1,
-                 qkv_bias=True,
-                 drop_rate=0.,
-                 attn_drop_rate=0.,
-                 drop_path_rate=0.,
-                 with_cls_token=True,
-                 output_cls_token=False,
-                 norm_cfg=dict(type='LN'),
-                 act_cfg=dict(type='GELU'),
-                 patch_norm=False,
-                 final_norm=False,
-                 interpolate_mode='bicubic',
-                 num_fcs=2,
-                 norm_eval=False,
-                 with_cp=False,
-                 pretrain_style='timm',
-                 pretrained=None,
-                 init_cfg=None):
-        super(Deit, self).__init__()
-
-        if isinstance(img_size, int):
-            img_size = to_2tuple(img_size)
-        elif isinstance(img_size, tuple):
-            if len(img_size) == 1:
-                img_size = to_2tuple(img_size[0])
-            assert len(img_size) == 2, \
-                f'The size of image should have length 1 or 2, ' \
-                f'but got {len(img_size)}'
-
-        assert pretrain_style in ['timm', 'mmcls']
-
-        if output_cls_token:
-            assert with_cls_token is True, f'with_cls_token must be True if' \
-                f'set output_cls_token to True, but got {with_cls_token}'
-
-        if isinstance(pretrained, str) or pretrained is None:
-            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
-                          'please use "init_cfg" instead')
-        else:
-            raise TypeError('pretrained must be a str or None')
-
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.interpolate_mode = interpolate_mode
-        self.norm_eval = norm_eval
-        self.with_cp = with_cp
-        self.pretrain_style = pretrain_style
-        self.pretrained = pretrained
-        self.init_cfg = init_cfg
-
-        self.patch_embed = PatchEmbed(
-            in_channels=in_channels,
-            embed_dims=embed_dims,
-            conv_type='Conv2d',
-            kernel_size=patch_size,
-            stride=patch_size,
-            pad_to_patch_size=True,
-            norm_cfg=norm_cfg if patch_norm else None,
-            init_cfg=None,
-        )
-
-        num_patches = (img_size[0] // patch_size) * \
-            (img_size[1] // patch_size)
-
-        self.with_cls_token = with_cls_token
-        self.output_cls_token = output_cls_token
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
-        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
-        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 2, embed_dims))
-        self.drop_after_pos = nn.Dropout(p=drop_rate)
-
-        if isinstance(out_indices, int):
-            if out_indices == -1:
-                out_indices = num_layers - 1
-            self.out_indices = [out_indices]
-        elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
-            self.out_indices = out_indices
-        else:
-            raise TypeError('out_indices must be type of int, list or tuple')
-
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]  # stochastic depth decay rule
-
-        self.layers = ModuleList()
-        for i in range(num_layers):
-            self.layers.append(
-                TransformerEncoderLayer(
-                    embed_dims=embed_dims,
-                    num_heads=num_heads,
-                    feedforward_channels=mlp_ratio * embed_dims,
-                    attn_drop_rate=attn_drop_rate,
-                    drop_rate=drop_rate,
-                    drop_path_rate=dpr[i],
-                    num_fcs=num_fcs,
-                    qkv_bias=qkv_bias,
-                    act_cfg=act_cfg,
-                    norm_cfg=norm_cfg,
-                    batch_first=True))
-
-        self.final_norm = final_norm
-        if final_norm:
-            self.norm1_name, norm1 = build_norm_layer(
-                norm_cfg, embed_dims, postfix=1)
-            self.add_module(self.norm1_name, norm1)
-
-    def init_weights(self):
-        if isinstance(self.pretrained, str):
-            logger = get_root_logger()
-            checkpoint = _load_checkpoint(self.pretrained, logger=logger, map_location='cpu')
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            elif 'model' in checkpoint:
-                state_dict = checkpoint['model']
-            else:
-                state_dict = checkpoint
-
-            if self.pretrain_style == 'timm':
-                # Because the refactor of vit is blocked by mmcls,
-                # so we firstly use timm pretrain weights to train
-                # downstream model.
-                state_dict = vit_convert(state_dict)
-
-            if 'pos_embed' in state_dict.keys():
-                if self.pos_embed.shape != state_dict['pos_embed'].shape:
-                    logger.info(msg=f'Resize the pos_embed shape from '
-                                f'{state_dict["pos_embed"].shape} to '
-                                f'{self.pos_embed.shape}')
-                    h, w = self.img_size
-                    pos_size = int(math.sqrt(state_dict['pos_embed'].shape[1] - 1))
-                    state_dict['pos_embed'] = self.resize_pos_embed(
-                        state_dict['pos_embed'],
-                        (h // self.patch_size, w // self.patch_size),
-                        (pos_size, pos_size),
-                        self.interpolate_mode)
-
-            self.load_state_dict(state_dict, False)
-
-    @staticmethod
-    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
-        """Resize pos_embed weights.
-
-        Resize pos_embed using bicubic interpolate method.
-        Args:
-            pos_embed (torch.Tensor): Position embedding weights.
-            input_shpae (tuple): Tuple for (downsampled input image height,
-                downsampled input image width).
-            pos_shape (tuple): The resolution of downsampled origin training
-                image.
-            mode (str): Algorithm used for upsampling:
-                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
-                ``'trilinear'``. Default: ``'nearest'``
-        Return:
-            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
-        """
-        assert pos_embed.ndim == 3, 'shape of pos_embed must be [B, L, C]'
-        pos_h, pos_w = pos_shape
-        cls_token_weight = pos_embed[:, 0]
-        dist_token_weight = pos_embed[:, 1]
-        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w):]
-        pos_embed_weight = pos_embed_weight.reshape(
-            1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
-        pos_embed_weight = resize(
-            pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
-        cls_token_weight = cls_token_weight.unsqueeze(1)
-        dist_token_weight = dist_token_weight.unsqueeze(1)
-        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
-        pos_embed = torch.cat((cls_token_weight, dist_token_weight, pos_embed_weight), dim=1)
-        return pos_embed
-
-    def _pos_embeding(self, patched_img, hw_shape, pos_embed):
-        """Positiong embeding method.
-
-        Resize the pos_embed, if the input image size doesn't match
-        the training size.
-        Args:
-            patched_img (torch.Tensor): The patched image, it should be
-                shape of [B, L1, C].
-            hw_shape (tuple): The downsampled image resolution.
-            pos_embed (torch.Tensor): The pos_embed weighs, it should be
-                shape of [B, L2, c].
-        Return:
-            torch.Tensor: The pos encoded image feature.
- """ - assert patched_img.ndim == 3 and pos_embed.ndim == 3, 'the shapes of patched_img and pos_embed must be [B, L, C]' - x_len, pos_len = patched_img.shape[1], pos_embed.shape[1] - if x_len != pos_len: - if pos_len == (self.img_size[0] // self.patch_size) * (self.img_size[1] // self.patch_size) + 2: - pos_h = self.img_size[0] // self.patch_size - pos_w = self.img_size[1] // self.patch_size - else: - raise ValueError('Unexpected shape of pos_embed, got {}.'.format(pos_embed.shape)) - pos_embed = self.resize_pos_embed(pos_embed, hw_shape, - (pos_h, pos_w), - self.interpolate_mode) - return self.drop_after_pos(patched_img + pos_embed) - - def forward(self, inputs): - B = inputs.shape[0] - - x, hw_shape = self.patch_embed(inputs), (self.patch_embed.DH, - self.patch_embed.DW) - # stole cls_tokens impl from Phil Wang, thanks - cls_tokens = self.cls_token.expand(B, -1, -1) - dist_tokens = self.dist_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, dist_tokens, x), dim=1) - x = self._pos_embeding(x, hw_shape, self.pos_embed) - - if not self.with_cls_token: - # Remove class token for transformer encoder input - x = x[:, 2:] - - outs = [] - for i, layer in enumerate(self.layers): - x = layer(x) - if i == len(self.layers) - 1: - if self.final_norm: - x = self.norm1(x) - if i in self.out_indices: - if self.with_cls_token: - # Remove class token and reshape token for decoder head - out = x[:, 2:] - else: - out = x - # print(f'index {i}: output.shape = {out.shape}') - B, _, C = out.shape - out = out.reshape(B, hw_shape[0], hw_shape[1], - C).permute(0, 3, 1, 2) - if self.output_cls_token: - out = [out, x[:, 0]] - outs.append(out) - - return tuple(outs) diff --git a/mmseg/models/decode_heads/Segmenter_maskT_head.py b/mmseg/models/decode_heads/Segmenter_maskT_head.py index 056ef50..37c24f1 100644 --- a/mmseg/models/decode_heads/Segmenter_maskT_head.py +++ b/mmseg/models/decode_heads/Segmenter_maskT_head.py @@ -91,10 +91,9 @@ def forward(self, x, mask=None): @HEADS.register_module() class SegmenterHead_maskT(BaseDecodeHead): - def __init__(self, img_size_ori, d_encoder, n_layers, n_heads, d_model, d_ff, drop_path_rate, dropout, **kwargs): - super(SegmenterHead_maskT, self).__init__(input_transform=None, **kwargs) + def __init__(self, img_size_ori, n_layers, n_heads, d_model, d_ff, drop_path_rate, dropout, **kwargs): + super(SegmenterHead_maskT, self).__init__(**kwargs) self.img_size_ori = img_size_ori - self.d_encoder = d_encoder self.d_model = d_model self.d_ff = d_ff self.scale = d_model ** -0.5 @@ -105,7 +104,7 @@ def __init__(self, img_size_ori, d_encoder, n_layers, n_heads, d_model, d_ff, dr ) self.cls_emb = nn.Parameter(torch.randn(1, self.num_classes, d_model)) - self.proj_dec = nn.Linear(d_encoder, d_model) + self.proj_dec = nn.Linear(self.in_channels, d_model) self.proj_patch = nn.Parameter(self.scale * torch.randn(d_model, d_model)) self.proj_classes = nn.Parameter(self.scale * torch.randn(d_model, d_model)) @@ -119,10 +118,9 @@ def __init__(self, img_size_ori, d_encoder, n_layers, n_heads, d_model, d_ff, dr def forward(self, inputs): x = self._transform_inputs(inputs) GS = x.shape[-1] - x = rearrange(x, "b n h w -> b (h w) n") - x = self.proj_dec(x) + cls_emb = self.cls_emb.expand(x.size(0), -1, -1) x = torch.cat((x, cls_emb), 1) for blk in self.blocks: diff --git a/train.py b/train.py index 51d00ae..3c2fbca 100644 --- a/train.py +++ b/train.py @@ -1,45 +1,30 @@ import argparse import copy import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0, " import os.path as osp import time -import 
diff --git a/train.py b/train.py
index 51d00ae..3c2fbca 100644
--- a/train.py
+++ b/train.py
@@ -1,45 +1,30 @@
 import argparse
 import copy
 import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "0, "
 import os.path as osp
 import time
-import bitsandbytes as bnb
+import re
+import loralib as lora
 
-import mmcv
 import torch
 import torch.nn as nn
+import mmcv
 from mmcv.runner import init_dist
 from mmcv.utils import Config, DictAction, get_git_hash
-
 from mmseg import __version__
 from mmseg.apis import set_random_seed
 from mmcv_custom import train_segmentor
 from mmseg.datasets import build_dataset
 from mmseg.models import build_segmentor
 from mmseg.utils import collect_env, get_root_logger
-
 from mmseg.models.backbones import EVA2
 
-import loralib as lora
-from transformers import (
-    PreTrainedModel,
-    PretrainedConfig,
-    AutoModelForCausalLM,
-    BitsAndBytesConfig,
-)
-from peft import (
-    prepare_model_for_kbit_training,
-    LoraConfig,
-    get_peft_model,
-    PeftModel
-)
-from peft.tuners.lora import LoraLayer
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Train a segmentor')
-    parser.add_argument('config', help='train config file path')
+    parser.add_argument('--config', help='train config file path')
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--load-from', help='the checkpoint file to load weights from')
@@ -81,6 +66,8 @@ def parse_args():
     return args
 
 
+
+'''
 def find_eva_linear_names(model):
     lora_module_names = set()
     for name, module in model.backbone.named_modules():
@@ -94,6 +81,22 @@ def find_eva_linear_names(model):
 
 
 def get_accelerate_model(args, model, checkpoint_dir=None):
+    from transformers import (
+        PreTrainedModel,
+        PretrainedConfig,
+        AutoModelForCausalLM,
+        BitsAndBytesConfig,
+    )
+    from peft import (
+        prepare_model_for_kbit_training,
+        LoraConfig,
+        get_peft_model,
+        PeftModel
+    )
+    from peft.tuners.lora import LoraLayer
+    from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+    import bitsandbytes as bnb
+
     pconfig=PretrainedConfig(is_encoder_decoder=True,torch_dtype=torch.float32)
     prtr=PreTrainedModel(pconfig)
     # prtr.save_pretrained('workbench/pretrained/')
@@ -146,10 +149,38 @@ def get_accelerate_model(args, model, checkpoint_dir=None):
                 module = module.to(torch.bfloat16)
     return model
+'''
+
+
+def get_finetune_model(model, code, verbose=False):
+    def freeze_match(name, f_list):
+        ret = False
+        for n in f_list:
+            ret = ret or (re.search(n, name) is not None)
+        return ret
+
+    freeze_list = []
+
+    if code == 1:
+        checkpoint = torch.load(model.backbone.pretrained, map_location='cpu')["model"]
+        freeze_list.extend([f"backbone.{key}" for key in checkpoint.keys()])
+
+    if verbose:
+        print("**frozen parameters**")
+        print(f"List: {freeze_list}")
+    for key, value in model.named_parameters():
+        if freeze_match(key, freeze_list):
+            value.requires_grad = False
+            if verbose:
+                print(key)
+    return model
 
 
-def main():
+
+
+def main(info, verbose=False):
     args = parse_args()
+    finetune_code = info["finetune_code"]
 
     cfg = Config.fromfile(args.config)
     if args.options is not None:
@@ -164,8 +195,8 @@
         cfg.work_dir = args.work_dir
     elif cfg.get('work_dir', None) is None:
         # use config filename as default work_dir if cfg.work_dir is None
-        cfg.work_dir = osp.join('./work_dirs',
-                                osp.splitext(osp.basename(args.config))[0])
+        cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
+        cfg.work_dir = osp.join(cfg.work_dir, f"finetune_{finetune_code}")
     if args.load_from is not None:
         cfg.load_from = args.load_from
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
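
A brief note on the get_finetune_model helper added above: with finetune_code=1 it reads the EVA-02 checkpoint's key list and freezes every model parameter whose name matches one of those keys via re.search, so the loaded backbone weights stay fixed while the decode head (and any backbone parameters absent from the checkpoint) remain trainable. A self-contained illustration of that matching logic, with hypothetical key names, not the repository's code:

    import re

    # Keys as they might appear in a pretrained checkpoint (hypothetical examples).
    checkpoint_keys = ["patch_embed.proj.weight", "blocks.0.attn.qkv.weight"]
    freeze_list = [f"backbone.{k}" for k in checkpoint_keys]

    def should_freeze(param_name, patterns):
        # Same idea as freeze_match: freeze if any pattern is found in the name.
        return any(re.search(p, param_name) is not None for p in patterns)

    print(should_freeze("backbone.patch_embed.proj.weight", freeze_list))  # True
    print(should_freeze("decode_head.proj_dec.weight", freeze_list))       # False
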
@@ -174,7 +205,6 @@
     if args.gpu_ids is not None:
         cfg.gpu_ids = args.gpu_ids
     else:
         cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
-    # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
         distributed = False
@@ -198,16 +228,14 @@
     env_info_dict = collect_env()
     env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
     dash_line = '-' * 60 + '\n'
-    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
-                dash_line)
+    if verbose:
+        logger.info('Environment info:\n' + dash_line + env_info + '\n' + dash_line)
+        logger.info(f'Config:\n{cfg.pretty_text}')
     meta['env_info'] = env_info
 
-    # log some basic info
-    logger.info(f'Distributed training: {distributed}')
-    logger.info(f'Config:\n{cfg.pretty_text}')
-
     # set random seeds
     if args.seed is not None:
+        logger.info(f'Distributed training: {distributed}')
         logger.info(f'Set random seed to {args.seed}, deterministic: '
                     f'{args.deterministic}')
         set_random_seed(args.seed, deterministic=args.deterministic)
@@ -220,11 +248,12 @@
         train_cfg=cfg.get('train_cfg'),
         test_cfg=cfg.get('test_cfg')
     )
-
-    # for k,v in model.named_parameters():
-    #     print('{}: {}'.format(k, v.requires_grad))
-    # logger.info(model)
-
+    # finetune_code: {0: non_freeze; 1: freeze_loaded_eva2}
+    if finetune_code > 0:
+        model = get_finetune_model(model, finetune_code, verbose=verbose)
+    if verbose:
+        logger.info(model)
+
     datasets = [build_dataset(cfg.data.train)]
     if len(cfg.workflow) == 2:
         val_dataset = copy.deepcopy(cfg.data.val)
@@ -252,4 +281,7 @@
 
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    hyper_info = {
+        "finetune_code": 1,
+    }
+    main(hyper_info)  # verbose=True
\ No newline at end of file
diff --git a/train.sh b/train.sh
index ab1a2d9..be98e0e 100644
--- a/train.sh
+++ b/train.sh
@@ -1,9 +1,11 @@
 #!/usr/bin/env bash
 
-GPUS=4
+GPU=0
 CONFIGS="configs/eva2_hybrid/Segmenter_EVA02_large_24_512_slide_80k.py"
+WORK_DIR="workbench/eva2_segmenter/"
+LOAD=""
+RESUME=""
 
-python -m torch.distributed.launch --nproc_per_node=${GPUS} \
-    --use_env train.py --launcher pytorch \
-    ${CONFIGS} --seed 0 --deterministic --gpus ${GPUS} \
-    # --load-from workbench/iter_60000.pth
\ No newline at end of file
+python train.py --config ${CONFIGS} --seed 0 --deterministic --gpu-ids ${GPU} \
+    --work-dir ${WORK_DIR} \
+# --load-from ${LOAD}
\ No newline at end of file