diff --git a/README.md b/README.md index ac7f9b6..90a6de7 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing -PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with high dimension and large volume, multiple modalities and difficulies in annotating. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised, self-supervised, and weakly supervised learning, and learning with noisy annotations. Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. If you use this toolkit, please cite the following paper: @@ -23,7 +23,7 @@ BibTeX entry: # Features PyMIC provides flixible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: -* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, weakly-supervised and noisy-label learning. +* Support for annotation-efficient image segmentation, especially for semi-supervised, self-supervised, self-supervised, weakly-supervised and noisy-label learning. * User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc) and easily integrate them into PyMIC. * Easy-to-use I/O interface to read and write different 2D and 3D images. * Various data pre-processing/transformation methods before sending a tensor into a network. @@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.4.0, run: +To install a specific version of PYMIC such as 0.4.1, run: ```bash -pip install PYMIC==0.4.0 +pip install PYMIC==0.4.1 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: diff --git a/pymic/loss/seg/mumford_shah.py b/pymic/loss/seg/mumford_shah.py index 6da51b5..db9368a 100644 --- a/pymic/loss/seg/mumford_shah.py +++ b/pymic/loss/seg/mumford_shah.py @@ -3,8 +3,9 @@ import torch import torch.nn as nn +from pymic.loss.seg.abstract import AbstractSegLoss -class MumfordShahLoss(nn.Module): +class MumfordShahLoss(AbstractSegLoss): """ Implementation of Mumford Shah Loss for weakly supervised learning. @@ -76,8 +77,8 @@ def forward(self, loss_input_dict): image = loss_input_dict['image'] if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) pred_shape = list(predict.shape) if(len(pred_shape) == 5): diff --git a/pymic/loss/seg/slsr.py b/pymic/loss/seg/slsr.py index d5c4151..92adea7 100644 --- a/pymic/loss/seg/slsr.py +++ b/pymic/loss/seg/slsr.py @@ -38,8 +38,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): diff --git a/pymic/loss/seg/ssl.py b/pymic/loss/seg/ssl.py index 0bf276f..3a7430a 100644 --- a/pymic/loss/seg/ssl.py +++ b/pymic/loss/seg/ssl.py @@ -34,8 +34,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) # for numeric stability predict = predict * 0.999 + 5e-4 @@ -70,8 +70,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) # for numeric stability predict = predict * 0.999 + 5e-4 diff --git a/pymic/net/net3d/trans3d/HiFormer_v1.py b/pymic/net/net3d/trans3d/HiFormer_v1.py deleted file mode 100644 index af73683..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v1.py +++ /dev/null @@ -1,1010 +0,0 @@ -from einops import rearrange -from copy import deepcopy -from nnformer.utilities.nd_softmax import softmax_helper -from torch import nn -import torch -import numpy as np -import torch.nn.functional -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_3tuple, trunc_normal_ -from pymic.net.net3d.unet3d import ConvBlock, DownBlock -# from nnFormer -class ContiguousGrad(torch.autograd.Function): - @staticmethod - def forward(ctx, x): - return x - @staticmethod - def backward(ctx, grad_out): - return grad_out.contiguous() - -# from nnFormer -class Mlp(nn.Module): - """ Multilayer perceptron.""" - - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - -# from nnFormer -def window_partition(x, window_size): - - B, S, H, W, C = x.shape - x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) - return windows - -# from nnFormer -def window_reverse(windows, window_size, S, H, W): - - B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) - x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) - return x - - -# from nnFormer -class SwinTransformerBlock_kv(nn.Module): - - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention_kv( - dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - #self.window_size=to_3tuple(self.window_size) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - def forward(self, x, mask_matrix,skip=None,x_up=None): - - B, L, C = x.shape - S, H, W = self.input_resolution - - assert L == S * H * W, "input feature has wrong size" - - shortcut = x - skip = self.norm1(skip) - x_up = self.norm1(x_up) - - skip = skip.view(B, S, H, W, C) - x_up = x_up.view(B, S, H, W, C) - x = x.view(B, S, H, W, C) - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - pad_g = (self.window_size - S % self.window_size) % self.window_size - - skip = F.pad(skip, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - x_up = F.pad(x_up, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - _, Sp, Hp, Wp, _ = skip.shape - - - - # cyclic shift - if self.shift_size > 0: - skip = torch.roll(skip, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - x_up = torch.roll(x_up, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - attn_mask = mask_matrix - else: - skip = skip - x_up=x_up - attn_mask = None - # partition windows - skip = window_partition(skip, self.window_size) - skip = skip.view(-1, self.window_size * self.window_size * self.window_size, - C) - x_up = window_partition(x_up, self.window_size) - x_up = x_up.view(-1, self.window_size * self.window_size * self.window_size, - C) - attn_windows=self.attn(skip,x_up,mask=attn_mask,pos_embed=None) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0 or pad_g > 0: - x = x[:, :S, :H, :W, :].contiguous() - - x = x.view(B, S * H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - -# from nnFormer -class WindowAttention_kv(nn.Module): - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), - num_heads)) - - # get pair-wise relative position index for each token inside the window - coords_s = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 - relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.softmax = nn.Softmax(dim=-1) - trunc_normal_(self.relative_position_bias_table, std=.02) - - - def forward(self, skip,x_up,pos_embed=None, mask=None): - - B_, N, C = skip.shape - - kv = self.kv(skip) - q = x_up - - kv=kv.reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q = q.reshape(B_,N,self.num_heads,C//self.num_heads).permute(0,2,1,3).contiguous() - k,v = kv[0], kv[1] - q = q * self.scale - attn = (q @ k.transpose(-2, -1).contiguous()) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] * self.window_size[2], - self.window_size[0] * self.window_size[1] * self.window_size[2], -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() - if pos_embed is not None: - x = x + pos_embed - x = self.proj(x) - x = self.proj_drop(x) - return x - -# from nnFormer -class WindowAttention(nn.Module): - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), - num_heads)) - - # get pair-wise relative position index for each token inside the window - coords_s = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 - relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - trunc_normal_(self.relative_position_bias_table, std=.02) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None,pos_embed=None): - - B_, N, C = x.shape - - qkv = self.qkv(x) - - qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1).contiguous()) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] * self.window_size[2], - self.window_size[0] * self.window_size[1] * self.window_size[2], -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() - if pos_embed is not None: - x = x+pos_embed - x = self.proj(x) - x = self.proj_drop(x) - return x - -# from nnFormer -class SwinTransformerBlock(nn.Module): - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - - self.attn = WindowAttention( - dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - - def forward(self, x, mask_matrix): - - B, L, C = x.shape - S, H, W = self.input_resolution - - assert L == S * H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, S, H, W, C) - - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - pad_g = (self.window_size - S % self.window_size) % self.window_size - - x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - _, Sp, Hp, Wp, _ = x.shape - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - attn_mask = mask_matrix - else: - shifted_x = x - attn_mask = None - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, - C) - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0 or pad_g > 0: - x = x[:, :S, :H, :W, :].contiguous() - - x = x.view(B, S * H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - -# from nnFormer -class PatchMerging(nn.Module): - - - def __init__(self, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.reduction = nn.Conv3d(dim,dim*2,kernel_size=3,stride=2,padding=1) - - self.norm = norm_layer(dim) - - def forward(self, x, S, H, W): - - B, L, C = x.shape - assert L == H * W * S, "input feature has wrong size" - x = x.view(B, S, H, W, C) - - x = F.gelu(x) - x = self.norm(x) - x=x.permute(0,4,1,2,3).contiguous() - x=self.reduction(x) - x=x.permute(0,2,3,4,1).contiguous().view(B,-1,2*C) - - return x - -# from nnFormer -class Patch_Expanding(nn.Module): - def __init__(self, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - - self.norm = norm_layer(dim) - self.up=nn.ConvTranspose3d(dim,dim//2,2,2) - def forward(self, x, S, H, W): - - - B, L, C = x.shape - assert L == H * W * S, "input feature has wrong size" - - x = x.view(B, S, H, W, C) - - - - x = self.norm(x) - x=x.permute(0,4,1,2,3).contiguous() - x = self.up(x) - x = ContiguousGrad.apply(x) - x=x.permute(0,2,3,4,1).contiguous().view(B,-1,C//2) - - return x - -# from nnFormer -class BasicLayer(nn.Module): - - def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - downsample=True - ): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - # build blocks - - self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, S, H, W): - - - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - for blk in self.blocks: - - x = blk(x, attn_mask) - if self.downsample is not None: - x_down = self.downsample(x, S, H, W) - Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 - return x, S, H, W, x_down, Ws, Wh, Ww - else: - return x, S, H, W, x, S, H, W - -# from nnFormer -class BasicLayer_up(nn.Module): - - def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - upsample=True - ): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - - - # build blocks - self.blocks = nn.ModuleList() - self.blocks.append( - SwinTransformerBlock_kv( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 , - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[0] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - ) - for i in range(depth-1): - self.blocks.append( - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=window_size // 2 , - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i+1] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - ) - - - - self.Upsample = upsample(dim=2*dim, norm_layer=norm_layer) - def forward(self, x,skip, S, H, W): - - - x_up = self.Upsample(x, S, H, W) - - x = x_up + skip - S, H, W = S * 2, H * 2, W * 2 - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) # 3d��3��winds�˻�����Ŀ�Ǻܴ�ģ�����winds����̫�� - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - x = self.blocks[0](x, attn_mask,skip=skip,x_up=x_up) - for i in range(self.depth-1): - x = self.blocks[i+1](x,attn_mask) - - return x, S, H, W - - -# from nnFormer -class project(nn.Module): - def __init__(self,in_dim,out_dim,stride,padding,activate,norm,last=False): - super().__init__() - self.out_dim=out_dim - self.conv1=nn.Conv3d(in_dim,out_dim,kernel_size=3,stride=stride,padding=padding) - self.conv2=nn.Conv3d(out_dim,out_dim,kernel_size=3,stride=1,padding=1) - self.activate=activate() - self.norm1=norm(out_dim) - self.last=last - if not last: - self.norm2=norm(out_dim) - - def forward(self,x): - x=self.conv1(x) - x=self.activate(x) - #norm1 - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm1(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) - - - x=self.conv2(x) - if not self.last: - x=self.activate(x) - #norm2 - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm2(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_dim, Ws, Wh, Ww) - return x - - -# from nnFormer -class PatchEmbed_backup(nn.Module): - def __init__(self, patch_size=4, in_chans=4, embed_dim=96, norm_layer=None): - super().__init__() - patch_size = to_3tuple(patch_size) - self.patch_size = patch_size - - self.in_chans = in_chans - self.embed_dim = embed_dim - stride1=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] - stride2=[patch_size[0]//2,patch_size[1]//2,patch_size[2]//2] - self.proj1 = project(in_chans,embed_dim//2,stride1,1,nn.GELU,nn.LayerNorm,False) - self.proj2 = project(embed_dim//2,embed_dim,stride2,1,nn.GELU,nn.LayerNorm,True) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - """Forward function.""" - # padding - _, _, S, H, W = x.size() - if W % self.patch_size[2] != 0: - x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2])) - if H % self.patch_size[1] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1])) - if S % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - S % self.patch_size[0])) - x = self.proj1(x) # B C Ws Wh Ww - x = self.proj2(x) # B C Ws Wh Ww - if self.norm is not None: - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm(x) - x = x.transpose(1, 2).contiguous().view(-1, self.embed_dim, Ws, Wh, Ww) - - return x - - -class PatchEmbed(nn.Module): - """ - replace patch embed with conv layers""" - def __init__(self, in_chns=1, ft_chns = [32, 64, 128], dropout = [0, 0, 0.2]): - super().__init__() - self.in_conv= ConvBlock(in_chns, ft_chns[0], dropout[0]) - self.down1 = DownBlock(ft_chns[0], ft_chns[1], dropout[1]) - self.down2 = DownBlock(ft_chns[1], ft_chns[2], dropout[2]) - - - def forward(self, x): - """Forward function.""" - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - return x2 - -# from nnFormer -class Encoder(nn.Module): - - def __init__(self, - pretrain_img_size=224, - patch_size=4, - in_chans=1 , - embed_dim=96, - depths=[2, 2, 2, 2], - num_heads=[4, 8, 16, 32], - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - out_indices=(0, 1, 2, 3) - ): - super().__init__() - - self.pretrain_img_size = pretrain_img_size - - self.num_layers = len(depths) - print("number of layers in encoder", self.num_layers, depths) - self.embed_dim = embed_dim - self.patch_norm = patch_norm - self.out_indices = out_indices - - # split image into non-overlapping patches - # self.patch_embed = PatchEmbed( - # patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, - # norm_layer=norm_layer if self.patch_norm else None) - self.patch_embed = PatchEmbed(in_chans, ft_chns=[embed_dim // 4, embed_dim //2, embed_dim]) - - - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = BasicLayer( - dim=int(embed_dim * 2 ** i_layer), - input_resolution=( - pretrain_img_size[0] // patch_size[0] // 2 ** i_layer, pretrain_img_size[1] // patch_size[1] // 2 ** i_layer, - pretrain_img_size[2] // patch_size[2] // 2 ** i_layer), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size[i_layer], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum( - depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - downsample=PatchMerging - if (i_layer < self.num_layers - 1) else None - ) - self.layers.append(layer) - - num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] - self.num_features = num_features - - # add a norm layer for each output - for i_layer in out_indices: - layer = norm_layer(num_features[i_layer]) - layer_name = f'norm{i_layer}' - self.add_module(layer_name, layer) - - - def forward(self, x): - """Forward function.""" - - x = self.patch_embed(x) - down=[] - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.pos_drop(x) - - - for i in range(self.num_layers): - layer = self.layers[i] - x_out, S, H, W, x, Ws, Wh, Ww = layer(x, Ws, Wh, Ww) - if i in self.out_indices: - norm_layer = getattr(self, f'norm{i}') - x_out = norm_layer(x_out) - - out = x_out.view(-1, S, H, W, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() - - down.append(out) - return down - - -# from nnFormer -class Decoder(nn.Module): - def __init__(self, - pretrain_img_size, - embed_dim, - patch_size=4, - depths=[2,2,2], - num_heads=[24,12,6], - window_size=4, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm - ): - super().__init__() - - - self.num_layers = len(depths) - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build layers - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers)[::-1]: - - layer = BasicLayer_up( - dim=int(embed_dim * 2 ** (len(depths)-i_layer-1)), - input_resolution=( - pretrain_img_size[0] // patch_size[0] // 2 ** (len(depths)-i_layer-1), pretrain_img_size[1] // patch_size[1] // 2 ** (len(depths)-i_layer-1), - pretrain_img_size[2] // patch_size[2] // 2 ** (len(depths)-i_layer-1)), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size[i_layer], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[sum( - depths[:i_layer]):sum(depths[:i_layer + 1])], - norm_layer=norm_layer, - upsample=Patch_Expanding - ) - self.layers.append(layer) - self.num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] - def forward(self,x,skips): - - outs=[] - S, H, W = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - for index,i in enumerate(skips): - i = i.flatten(2).transpose(1, 2).contiguous() - skips[index]=i - x = self.pos_drop(x) - - for i in range(self.num_layers)[::-1]: - - layer = self.layers[i] - - x, S, H, W, = layer(x,skips[i], S, H, W) - out = x.view(-1, S, H, W, self.num_features[i]) - outs.append(out) - return outs - - -class final_patch_expanding(nn.Module): - def __init__(self,dim,num_class,patch_size): - super().__init__() - self.up=nn.ConvTranspose3d(dim,num_class,patch_size,patch_size) - - def forward(self,x): - x=x.permute(0,4,1,2,3).contiguous() - x=self.up(x) - - - return x - - - - - -class HiFormer_v1(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v1, self).__init__() - # crop_size=[96,96,96], - # embedding_dim=192, - # input_channels=1, - # num_classes=9, - # conv_op=nn.Conv3d, - # depths=[2,2,2,2], - # num_heads=[6, 12, 24, 48], - # patch_size=[4,4,4], - # window_size=[4,4,8,4], - # deep_supervision=False): - - crop_size = params["input_size"] - embed_dim = params.get("embedding_dim", 192) - input_channels = params["in_chns"] - num_classes = params["class_num"] - self.conv_op = nn.Conv3d - depths = params.get("depths", [2, 2, 2, 2]) - num_heads = params.get("num_heads", [6, 12, 24, 48]) - patch_size = params.get("patch_size", [4, 4, 4]) # for patch embedding - window_size = params.get("window_size", [4, 4, 8, 4]) # for swin transformer window - self._deep_supervision = params.get("deep_supervision", False) - self.do_ds = params.get("deep_supervision", False) - - - self.num_classes = num_classes - self.upscale_logits_ops = [] - self.upscale_logits_ops.append(lambda x: x) - - self.model_down=Encoder(pretrain_img_size=crop_size,window_size=window_size,embed_dim=embed_dim, - patch_size=patch_size,depths=depths,num_heads=num_heads,in_chans=input_channels, out_indices=range(len(depths))) - self.decoder=Decoder(pretrain_img_size=crop_size,embed_dim=embed_dim,window_size=window_size[::-1][1:],patch_size=patch_size,num_heads=num_heads[::-1][:-1],depths=depths[::-1][1:]) - - self.final=[] - if self.do_ds: - - for i in range(len(depths)-1): - self.final.append(final_patch_expanding(embed_dim*2**i,num_classes,patch_size=patch_size)) - - else: - self.final.append(final_patch_expanding(embed_dim,num_classes,patch_size=patch_size)) - - self.final=nn.ModuleList(self.final) - - - def forward(self, x): - - - seg_outputs=[] - skips = self.model_down(x) - neck=skips[-1] - - out=self.decoder(neck,skips) - - - - if self.do_ds: - for i in range(len(out)): - seg_outputs.append(self.final[-(i+1)](out[i])) - - - return seg_outputs[::-1] - else: - seg_outputs.append(self.final[0](out[-1])) - return seg_outputs[-1] - - -if __name__ == "__main__": - # params = {"input_size": [96, 96, 96], - # "in_chns": 1, - # "depth": [2, 2, 2, 2], - # "num_heads": [6, 12, 24, 48], - # "window_size": [6, 6, 6, 3], - # "class_num": 5} - params = {"input_size": [96, 96, 96], - "in_chns": 1, - "depths": [2, 2, 2], - "num_heads": [6, 12, 24], - "window_size": [6, 6, 6], - "class_num": 5} - Net = HiFormer_v1(params) - Net = Net.double() - - x = np.random.rand(1, 1, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v2.py b/pymic/net/net3d/trans3d/HiFormer_v2.py deleted file mode 100644 index 7d4c440..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v2.py +++ /dev/null @@ -1,381 +0,0 @@ - -import torch -import numpy as np -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from copy import deepcopy -from torch import nn -from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - - -class DownSample(nn.Module): - def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): - super(DownSample, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - stride = [1, 2, 2] - padding = [0, 1, 1] - else: - kernel_size = 3 - stride = 2 - padding = 1 - - if(first_layer): - self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride) - else: - self.down = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride), - ) - - def forward(self, x): - return self.down(x) - - - -class ConvTransBlock(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - self.trans = BasicLayer( - dim= chns, - input_resolution= input_resolution, - depth=depth, - num_heads=num_head, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=drop_path_rate, - norm_layer=norm_layer, - downsample= None - ) - self.norm_layer = nn.LayerNorm(chns) - self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.pos_drop(x) - x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # x2 = self.norm_layer(x2) - x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - return x1 + x2 - - -class UpCatBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): - super(UpCatBlock, self).__init__() - assert(up_dim == 2 or up_dim == 3) - if(up_dim == 2): - kernel_size, stride = [1, 2, 2], [1, 2, 2] - else: - kernel_size, stride = 2, 2 - self.up = nn.ConvTranspose3d(chns_h, chns_l, - kernel_size = kernel_size, stride=stride) - - if(conv_dim == 2): - kernel_size, padding = [1, 3, 3], [0, 1, 1] - else: - kernel_size, padding = 3, 1 - self.conv = nn.Sequential( - nn.BatchNorm3d(chns_l*2), - nn.PReLU(), - nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x_l, x_h): - # print("input shapes", x1.shape, x2.shape) - # print("after upsample", x1.shape) - y = torch.cat([x_l, self.up(x_h)], dim=1) - return self.conv(y) - -class Encoder(nn.Module): - def __init__(self, - in_chns = 1 , - ft_chns = [48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - ): - super().__init__() - - self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) - self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) - self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) - self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) - - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] - - self.conv_t2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv_t3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv_t4 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t4, - window_size = window_sizes[2], - depth = depths[2], - num_head = num_heads[2], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - - - - def forward(self, x): - """Forward function.""" - x1 = self.conv1(self.down1(x)) - x2 = self.conv2(self.down2(x1)) - x2 = self.conv_t2(x2) - x3 = self.conv_t3(self.down3(x2)) - x4 = self.conv_t4(self.down4(x3)) - - return x1, x2, x3, x4 - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, - ft_chns = [48, 192, 384, 768], - input_size = [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - class_num = 2, - multiscale_pred = False - ): - super(Decoder, self).__init__() - - self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) - self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) - self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - - kernel_size, stride = 2, 2 - if down_dims[0] == 2: - kernel_size, stride = [1, 2, 2], [1, 2, 2] - self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, - kernel_size = kernel_size, stride= stride) - - self.mul_pred = multiscale_pred - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) - self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) - self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) - - def forward(self, x): - x1, x2, x3, x4 = x - x_d3 = self.conv3(self.up3(x3, x4)) - x_d2 = self.conv2(self.up2(x2, x_d3)) - x_d1 = self.conv1(self.up1(x1, x_d2)) - - output = self.out_conv0(x_d1) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v2(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v2, self).__init__() - in_chns = params["in_chns"] - class_num = params["class_num"] - input_size = params["input_size"] - ft_chns = params.get("feature_chns", [48, 192, 384, 764]) - down_dims = params.get("down_dims", [2, 2, 3, 3]) - conv_dims = params.get("conv_dims", [2, 3, 3, 3]) - dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) - depths = params.get("depths", [2, 2, 2]) - num_heads = params.get("num_heads", [4, 8, 16]) - window_sizes= params.get("window_sizes", [6, 6, 6]) - multiscale_pred = params.get("multiscale_pred", False) - - self.encoder = Encoder(in_chns, - ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes) - - self.decoder = Decoder(ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - class_num = class_num, - multiscale_pred = multiscale_pred - ) - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -if __name__ == "__main__": - params = {"input_size": [32, 128, 128], - "in_chns": 1, - "down_dims": [2, 2, 3, 3], - "conv_dims": [2, 3, 3, 3], - "feature_chns": [96, 192, 384, 768], - "class_num": 5, - "multiscale_pred": True} - Net = HiFormer_v2(params) - Net = Net.double() - - x = np.random.rand(1, 1, 32, 128, 128) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - if(params['multiscale_pred']): - for yi in y: - print(yi.shape) - else: - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v3.py b/pymic/net/net3d/trans3d/HiFormer_v3.py deleted file mode 100644 index 2f8c831..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v3.py +++ /dev/null @@ -1,455 +0,0 @@ - -import torch -import numpy as np -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from copy import deepcopy -from torch import nn -from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - - -class DownSample(nn.Module): - def __init__(self, in_channels, out_channels, dim = 2, first_layer = False): - super(DownSample, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - stride = [1, 2, 2] - padding = [0, 1, 1] - else: - kernel_size = 3 - stride = 2 - padding = 1 - - if(first_layer): - self.down = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride) - else: - self.down = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, - padding=padding, stride = stride), - ) - - def forward(self, x): - return self.down(x) - - - -class ConvTransBlock_backup(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - self.trans = BasicLayer( - dim= chns, - input_resolution= input_resolution, - depth=depth, - num_heads=num_head, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=drop_path_rate, - norm_layer=norm_layer, - downsample= None - ) - self.norm_layer = nn.LayerNorm(chns) - self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.pos_drop(x) - x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # x2 = self.norm_layer(x2) - x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - return x1 + x2 - -# only using the conv block -class ConvTransBlock(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - # self.trans = BasicLayer( - # dim= chns, - # input_resolution= input_resolution, - # depth=depth, - # num_heads=num_head, - # window_size=window_size, - # mlp_ratio=mlp_ratio, - # qkv_bias=qkv_bias, - # qk_scale=qk_scale, - # drop=drop_rate, - # attn_drop=attn_drop_rate, - # drop_path=drop_path_rate, - # norm_layer=norm_layer, - # downsample= None - # ) - # self.norm_layer = nn.LayerNorm(chns) - # self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - return x1 - # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - # x = x.flatten(2).transpose(1, 2).contiguous() - # x = self.pos_drop(x) - # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # # x2 = self.norm_layer(x2) - # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - # return x1 + x2 - -class UpCatBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): - super(UpCatBlock, self).__init__() - assert(up_dim == 2 or up_dim == 3) - if(up_dim == 2): - kernel_size, stride = [1, 2, 2], [1, 2, 2] - else: - kernel_size, stride = 2, 2 - self.up = nn.ConvTranspose3d(chns_h, chns_l, - kernel_size = kernel_size, stride=stride) - - if(conv_dim == 2): - kernel_size, padding = [1, 3, 3], [0, 1, 1] - else: - kernel_size, padding = 3, 1 - self.conv = nn.Sequential( - nn.BatchNorm3d(chns_l*2), - nn.PReLU(), - nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x_l, x_h): - # print("input shapes", x1.shape, x2.shape) - # print("after upsample", x1.shape) - y = torch.cat([x_l, self.up(x_h)], dim=1) - return self.conv(y) - -class Encoder(nn.Module): - def __init__(self, - in_chns = 1 , - ft_chns = [48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - high_res = False, - ): - super().__init__() - self.high_res = high_res - - self.down1 = DownSample(in_chns, ft_chns[0], down_dims[0], first_layer=True) - self.down2 = DownSample(ft_chns[0], ft_chns[1], down_dims[1]) - self.down3 = DownSample(ft_chns[1], ft_chns[2], down_dims[2]) - self.down4 = DownSample(ft_chns[2], ft_chns[3], down_dims[3]) - - if(high_res): - self.conv0 = ConvBlock(in_chns, ft_chns[0] // 2, 3, 0) - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] - - self.conv_t2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv_t3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv_t4 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t4, - window_size = window_sizes[2], - depth = depths[2], - num_head = num_heads[2], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - - - - def forward(self, x): - """Forward function.""" - if(self.high_res): - x0 = self.conv0(x) - x1 = self.conv1(self.down1(x)) - x2 = self.conv2(self.down2(x1)) - x2 = self.conv_t2(x2) - x3 = self.conv_t3(self.down3(x2)) - x4 = self.conv_t4(self.down4(x3)) - if(self.high_res): - return x0, x1, x2, x3, x4 - else: - return x1, x2, x3, x4 - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, - ft_chns = [48, 192, 384, 768], - input_size = [32, 128, 128], - down_dims = [2, 2, 3, 3], - conv_dims = [2, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - high_res = False, - class_num = 2, - multiscale_pred = False - ): - super(Decoder, self).__init__() - self.high_res = high_res - if(self.high_res): - self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) - self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) - self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[1], conv_dims[0]) - self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[2], conv_dims[1]) - self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[3], conv_dims[2]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - - self.conv1 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv2 = ConvTransBlock(chns = ft_chns[1], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[1], - attn_drop_rate=dropout[1] - ) - self.conv3 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - - kernel_size, stride = 2, 2 - if down_dims[0] == 2: - kernel_size, stride = [1, 2, 2], [1, 2, 2] - if(self.high_res): - self.out_conv0 = nn.Conv3d(ft_chns[0] // 2, class_num, - kernel_size = [1, 3, 3], padding = [0, 1, 1]) - else: - self.out_conv0 = nn.ConvTranspose3d(ft_chns[0], class_num, - kernel_size = kernel_size, stride= stride) - - self.mul_pred = multiscale_pred - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(ft_chns[0], class_num, kernel_size = 1) - self.out_conv2 = nn.Conv3d(ft_chns[1], class_num, kernel_size = 1) - self.out_conv3 = nn.Conv3d(ft_chns[2], class_num, kernel_size = 1) - - def forward(self, x): - if(self.high_res): - x0, x1, x2, x3, x4 = x - else: - x1, x2, x3, x4 = x - x_d3 = self.conv3(self.up3(x3, x4)) - x_d2 = self.conv2(self.up2(x2, x_d3)) - x_d1 = self.conv1(self.up1(x1, x_d2)) - if(self.high_res): - x_d0 = self.conv0(self.up0(x0, x_d1)) - output = self.out_conv0(x_d0) - else: - output = self.out_conv0(x_d1) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v3(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v3, self).__init__() - in_chns = params["in_chns"] - class_num = params["class_num"] - input_size = params["input_size"] - ft_chns = params.get("feature_chns", [48, 192, 384, 764]) - down_dims = params.get("down_dims", [2, 2, 3, 3]) - conv_dims = params.get("conv_dims", [2, 3, 3, 3]) - dropout = params.get('dropout', [0, 0.2, 0.2, 0.2]) - high_res = params.get("high_res", False) - depths = params.get("depths", [2, 2, 2]) - num_heads = params.get("num_heads", [4, 8, 16]) - window_sizes= params.get("window_sizes", [6, 6, 6]) - multiscale_pred = params.get("multiscale_pred", False) - - self.encoder = Encoder(in_chns, - ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - high_res = high_res) - - self.decoder = Decoder(ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - high_res = high_res, - class_num = class_num, - multiscale_pred = multiscale_pred - ) - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -if __name__ == "__main__": - params = {"input_size": [64, 96, 96], - "in_chns": 1, - "down_dims": [3, 3, 3, 3], - "conv_dims": [3, 3, 3, 3], - "feature_chns": [96, 192, 384, 768], - "high_res": True, - "class_num": 5, - "multiscale_pred": True} - Net = HiFormer_v3(params) - Net = Net.double() - - x = np.random.rand(1, 1, 64, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - if(params['multiscale_pred']): - for yi in y: - print(yi.shape) - else: - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v4.py b/pymic/net/net3d/trans3d/HiFormer_v4.py deleted file mode 100644 index f0c6087..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v4.py +++ /dev/null @@ -1,455 +0,0 @@ - -import torch -import numpy as np -import torch.utils.checkpoint as checkpoint -from einops import rearrange -from copy import deepcopy -from torch import nn -from pymic.net.net3d.trans3d.HiFormer_v1 import BasicLayer - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dim = 2, dropout_p = 0.0): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.PReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - - -class DownSample(nn.Module): - def __init__(self, in_channels, out_channels, down_dim = 3, conv_dim = 3): - super(DownSample, self).__init__() - assert(down_dim == 2 or down_dim == 3) - assert(conv_dim == 2 or conv_dim == 3) - - kernel_size = [1, 2, 2] if(down_dim == 2) else 2 - self.pool = nn.MaxPool3d(kernel_size) - - if(conv_dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv(self.pool(x)) - - - -# class ConvTransBlock(nn.Module): -# def __init__(self, -# input_resolution= [32, 32, 32], -# chns=96, -# depth=2, -# num_head=4, -# window_size=7, -# mlp_ratio=4., -# qkv_bias=True, -# qk_scale=None, -# drop_rate=0., -# attn_drop_rate=0., -# drop_path_rate=0.2, -# norm_layer=nn.LayerNorm, -# patch_norm=True, -# ): -# super().__init__() -# self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) -# self.trans = BasicLayer( -# dim= chns, -# input_resolution= input_resolution, -# depth=depth, -# num_heads=num_head, -# window_size=window_size, -# mlp_ratio=mlp_ratio, -# qkv_bias=qkv_bias, -# qk_scale=qk_scale, -# drop=drop_rate, -# attn_drop=attn_drop_rate, -# drop_path=drop_path_rate, -# norm_layer=norm_layer, -# downsample= None -# ) -# self.norm_layer = nn.LayerNorm(chns) -# self.pos_drop = nn.Dropout(p=drop_rate) - -# def forward(self, x): -# """Forward function.""" -# x1 = self.conv(x) -# C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) -# x = x.flatten(2).transpose(1, 2).contiguous() -# x = self.pos_drop(x) -# x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) -# # x2 = self.norm_layer(x2) -# x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() -# return x1 + x2 - -# only using the conv block -class ConvTransBlock(nn.Module): - def __init__(self, - input_resolution= [32, 32, 32], - chns=96, - depth=2, - num_head=4, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop_rate=0., - attn_drop_rate=0., - drop_path_rate=0.2, - norm_layer=nn.LayerNorm, - patch_norm=True, - ): - super().__init__() - self.conv = ConvBlock(chns, chns, dim = 3, dropout_p = drop_rate) - # self.trans = BasicLayer( - # dim= chns, - # input_resolution= input_resolution, - # depth=depth, - # num_heads=num_head, - # window_size=window_size, - # mlp_ratio=mlp_ratio, - # qkv_bias=qkv_bias, - # qk_scale=qk_scale, - # drop=drop_rate, - # attn_drop=attn_drop_rate, - # drop_path=drop_path_rate, - # norm_layer=norm_layer, - # downsample= None - # ) - # self.norm_layer = nn.LayerNorm(chns) - # self.pos_drop = nn.Dropout(p=drop_rate) - - def forward(self, x): - """Forward function.""" - x1 = self.conv(x) - return x1 - # C, Ws, Wh, Ww = x.size(1), x.size(2), x.size(3), x.size(4) - # x = x.flatten(2).transpose(1, 2).contiguous() - # x = self.pos_drop(x) - # x2, S, H, W, x, Ws, Wh, Ww = self.trans(x, Ws, Wh, Ww) - # # x2 = self.norm_layer(x2) - # x2 = x2.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - # return x1 + x2 - -class ConvLayer(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): - super(ConvLayer, self).__init__() - - self.conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.PReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), - ) - - def forward(self, x): - return self.conv(x) - -class UpCatBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, chns_l, chns_h, up_dim = 3, conv_dim = 3): - super(UpCatBlock, self).__init__() - assert(up_dim == 2 or up_dim == 3) - if(up_dim == 2): - kernel_size, stride = [1, 2, 2], [1, 2, 2] - else: - kernel_size, stride = 2, 2 - - self.up = nn.Sequential( - nn.BatchNorm3d(chns_h), - nn.PReLU(), - nn.ConvTranspose3d(chns_h, chns_l, kernel_size = kernel_size, stride=stride) - ) - - if(conv_dim == 2): - kernel_size, padding = [1, 3, 3], [0, 1, 1] - else: - kernel_size, padding = 3, 1 - self.conv = nn.Sequential( - nn.BatchNorm3d(chns_l*2), - nn.PReLU(), - nn.Conv3d(chns_l*2, chns_l, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x_l, x_h): - # print("input shapes", x1.shape, x2.shape) - # print("after upsample", x1.shape) - y = torch.cat([x_l, self.up(x_h)], dim=1) - return self.conv(y) - -class Encoder(nn.Module): - def __init__(self, - in_chns = 1 , - ft_chns = [24, 48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [3, 3, 3, 3, 3], - conv_dims = [3, 3, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - ): - super().__init__() - self.proj = nn.Conv3d(in_chns, ft_chns[0], kernel_size=3, padding=1) - self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - self.conv2 = ConvBlock(ft_chns[2], ft_chns[2], conv_dims[2], dropout[2]) - - self.down1 = DownSample(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[1]) - self.down2 = DownSample(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[2]) - self.down3 = DownSample(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[3]) - self.down4 = DownSample(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[4]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - r_t4 = [r_t3[i] // down_scales[3][i] for i in range(3)] - - self.conv_t2 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv_t3 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - self.conv_t4 = ConvTransBlock(chns = ft_chns[4], - input_resolution = r_t4, - window_size = window_sizes[2], - depth = depths[2], - num_head = num_heads[2], - drop_rate = dropout[4], - attn_drop_rate=dropout[4] - ) - - - - def forward(self, x): - """Forward function.""" - x0 = self.conv0(self.proj(x)) - x1 = self.conv1(self.down1(x0)) - x2 = self.conv2(self.down2(x1)) - x2 = self.conv_t2(x2) - x3 = self.conv_t3(self.down3(x2)) - x4 = self.conv_t4(self.down4(x3)) - return x0, x1, x2, x3, x4 - - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, - ft_chns = [24, 48, 192, 384, 768], - input_size= [32, 128, 128], - down_dims = [3, 3, 3, 3, 3], - conv_dims = [3, 3, 3, 3, 3], - dropout = [0, 0, 0.2, 0.2, 0.2], - depths = [2, 2, 2], - num_heads = [4, 8, 16], - window_sizes = [6, 6, 6], - class_num = 2, - multiscale_pred = False - ): - super(Decoder, self).__init__() - # self.up0 = UpCatBlock(ft_chns[0] // 2, ft_chns[0], down_dims[0], 3) - # self.conv0 = ConvBlock(ft_chns[0] // 2, ft_chns[0] // 2, 3, 0) - self.up1 = UpCatBlock(ft_chns[0], ft_chns[1], down_dims[0], conv_dims[0]) - self.up2 = UpCatBlock(ft_chns[1], ft_chns[2], down_dims[1], conv_dims[1]) - self.up3 = UpCatBlock(ft_chns[2], ft_chns[3], down_dims[2], conv_dims[2]) - self.up4 = UpCatBlock(ft_chns[3], ft_chns[4], down_dims[3], conv_dims[3]) - - down_scales = [] - for i in range(4): - down_scale = [1, 2, 2] if down_dims[i] == 2 else [2, 2, 2] - down_scales.append(down_scale) - - r_t2 = [input_size[i] // down_scales[0][i] // down_scales[1][i] for i in range(3)] - r_t3 = [r_t2[i] // down_scales[2][i] for i in range(3)] - - self.conv0 = ConvBlock(ft_chns[0], ft_chns[0], conv_dims[0], dropout[0]) - self.conv1 = ConvBlock(ft_chns[1], ft_chns[1], conv_dims[1], dropout[1]) - self.conv2 = ConvTransBlock(chns = ft_chns[2], - input_resolution = r_t2, - window_size = window_sizes[0], - depth = depths[0], - num_head = num_heads[0], - drop_rate = dropout[2], - attn_drop_rate=dropout[2] - ) - self.conv3 = ConvTransBlock(chns = ft_chns[3], - input_resolution = r_t3, - window_size = window_sizes[1], - depth = depths[1], - num_head = num_heads[1], - drop_rate = dropout[3], - attn_drop_rate=dropout[3] - ) - - self.out_conv0 = ConvLayer(ft_chns[0], class_num) - - self.mul_pred = multiscale_pred - if(self.mul_pred): - self.out_conv1 = ConvLayer(ft_chns[1], class_num) - self.out_conv2 = ConvLayer(ft_chns[2], class_num) - self.out_conv3 = ConvLayer(ft_chns[3], class_num) - - def forward(self, x): - x0, x1, x2, x3, x4 = x - - x_d3 = self.conv3(self.up4(x3, x4)) - x_d2 = self.conv2(self.up3(x2, x_d3)) - x_d1 = self.conv1(self.up2(x1, x_d2)) - x_d0 = self.conv0(self.up1(x0, x_d1)) - output = self.out_conv0(x_d0) - - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v4(nn.Module): - def __init__(self, params): - """ - replace the embedding layer with convolutional blocks - """ - super(HiFormer_v4, self).__init__() - in_chns = params["in_chns"] - class_num = params["class_num"] - input_size = params["input_size"] - ft_chns = params.get("feature_chns", [32, 64, 128, 256, 512]) - down_dims = params.get("down_dims", [3, 3, 3, 3, 3]) - conv_dims = params.get("conv_dims", [3, 3, 3, 3, 3]) - dropout = params.get('dropout', [0, 0, 0.2, 0.2, 0.2]) - depths = params.get("depths", [2, 2, 2]) - num_heads = params.get("num_heads", [4, 8, 16]) - window_sizes= params.get("window_sizes", [6, 6, 6]) - multiscale_pred = params.get("multiscale_pred", False) - - self.encoder = Encoder(in_chns, - ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes) - - self.decoder = Decoder(ft_chns = ft_chns, - input_size = input_size, - down_dims = down_dims, - conv_dims = conv_dims, - dropout = dropout, - depths = depths, - num_heads = num_heads, - window_sizes= window_sizes, - class_num = class_num, - multiscale_pred = multiscale_pred - ) - - def forward(self, x): - x = self.encoder(x) - x = self.decoder(x) - return x - - -if __name__ == "__main__": - params = {"input_size": [64, 96, 96], - "in_chns": 1, - "down_dims": [3, 3, 3, 3, 3], - "conv_dims": [3, 3, 3, 3, 3], - "feature_chns": [32, 64, 128, 256, 512], - "class_num": 5, - "multiscale_pred": True} - Net = HiFormer_v4(params) - Net = Net.double() - - x = np.random.rand(1, 1, 64, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - if(params['multiscale_pred']): - for yi in y: - print(yi.shape) - else: - print(y.shape) - - - diff --git a/pymic/net/net3d/trans3d/HiFormer_v5.py b/pymic/net/net3d/trans3d/HiFormer_v5.py deleted file mode 100644 index 5fcef5a..0000000 --- a/pymic/net/net3d/trans3d/HiFormer_v5.py +++ /dev/null @@ -1,308 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -import numpy as np -from torch.nn.functional import interpolate - - -class ConvBlock(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dropout_p = 0.0, dim = 3): - super(ConvBlock, self).__init__() - assert(dim == 2 or dim == 3) - if(dim == 2): - kernel_size = [1, 3, 3] - padding = [0, 1, 1] - else: - kernel_size = 3 - padding = 1 - - self.conv_conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.LeakyReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=padding), - nn.BatchNorm3d(out_channels), - nn.LeakyReLU(), - nn.Dropout(dropout_p), - nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, padding=padding), - ) - - def forward(self, x): - return self.conv_conv(x) - -class ConvLayer(nn.Module): - """ - 2D or 3D convolutional block - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dim: (int) Should be 2 or 3, for 2D and 3D convolution, respectively. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, kernel = 1, padding = 0): - super(ConvLayer, self).__init__() - - self.conv = nn.Sequential( - nn.BatchNorm3d(in_channels), - nn.LeakyReLU(), - nn.Conv3d(in_channels, out_channels, kernel_size=kernel, padding=padding), - ) - - def forward(self, x): - return self.conv(x) - -class DownBlock(nn.Module): - """ - 3D downsampling followed by ConvBlock - - :param in_channels: (int) Input channel number. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - """ - def __init__(self, in_channels, out_channels, dropout_p): - super(DownBlock, self).__init__() - self.maxpool_conv = nn.Sequential( - nn.MaxPool3d(2), - ConvBlock(in_channels, out_channels, dropout_p) - ) - - def forward(self, x): - return self.maxpool_conv(x) - -class UpBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): - super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.Sequential( - nn.BatchNorm3d(in_channels1), - nn.LeakyReLU(), - nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - ) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) - - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - -class Encoder(nn.Module): - """ - Encoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - """ - def __init__(self, params): - super(Encoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) - self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - - def forward(self, x): - x0 = self.in_conv(self.proj(x)) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - output = [x0, x1, x2, x3] - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - output.append(x4) - return output - -class Decoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, params): - super(Decoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params.get('multiscale_pred', False) - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) - if(self.mul_pred): - self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) - self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) - self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class HiFormer_v5(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. - """ - def __init__(self, params): - super(HiFormer_v5, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params['trilinear'] - self.mul_pred = self.params['multiscale_pred'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.proj = nn.Conv3d(self.in_chns, self.ft_chns[0], kernel_size=3, padding=1) - self.in_conv= ConvBlock(self.ft_chns[0], self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - dropout_p = self.dropout[3], trilinear=self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - dropout_p = self.dropout[2], trilinear=self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - dropout_p = self.dropout[1], trilinear=self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - dropout_p = self.dropout[0], trilinear=self.trilinear) - - self.out_conv = ConvLayer(self.ft_chns[0], self.n_class) - if(self.mul_pred): - self.out_conv1 = ConvLayer(self.ft_chns[1], self.n_class) - self.out_conv2 = ConvLayer(self.ft_chns[2], self.n_class) - self.out_conv3 = ConvLayer(self.ft_chns[3], self.n_class) - - def forward(self, x): - x0 = self.in_conv(self.proj(x)) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[32, 64, 128, 256, 512], - 'dropout' : [0, 0, 0, 0, 0.5], - 'trilinear': False, - 'multiscale_pred': False} - Net = HiFormer_v5(params) - Net = Net.double() - - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - y = y.detach().numpy() - print(y.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v1.py b/pymic/net/net3d/trans3d/MedFormer_v1.py deleted file mode 100644 index 1f2ed54..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_v1.py +++ /dev/null @@ -1,173 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import torch -import torch.nn as nn -import numpy as np -from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm -from pymic.net.net3d.unet3d import Encoder, Decoder - -class Attention(nn.Module): - def __init__(self, params): - super(Attention, self).__init__() - hidden_size = params["attention_hidden_size"] - self.num_attention_heads = params["attention_num_heads"] - self.attention_head_size = int(hidden_size / self.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = Linear(hidden_size, self.all_head_size) - self.key = Linear(hidden_size, self.all_head_size) - self.value = Linear(hidden_size, self.all_head_size) - - self.out = Linear(hidden_size, hidden_size) - self.attn_dropout = Dropout(params["attention_dropout_rate"]) - self.proj_dropout = Dropout(params["attention_dropout_rate"]) - - self.softmax = Softmax(dim=-1) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - attention_probs = self.softmax(attention_scores) - # weights = attention_probs if self.vis else None - attention_probs = self.attn_dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - attention_output = self.out(context_layer) - attention_output = self.proj_dropout(attention_output) - return attention_output - -class MLP(nn.Module): - def __init__(self, params): - super(MLP, self).__init__() - hidden_size = params["attention_hidden_size"] - mlp_dim = params["attention_mlp_dim"] - self.fc1 = Linear(hidden_size, mlp_dim) - self.fc2 = Linear(mlp_dim, hidden_size) - self.act_fn = torch.nn.functional.gelu - self.dropout = Dropout(params["attention_dropout_rate"]) - - self._init_weights() - - def _init_weights(self): - nn.init.xavier_uniform_(self.fc1.weight) - nn.init.xavier_uniform_(self.fc2.weight) - nn.init.normal_(self.fc1.bias, std=1e-6) - nn.init.normal_(self.fc2.bias, std=1e-6) - - def forward(self, x): - x = self.fc1(x) - x = self.act_fn(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - -class Block(nn.Module): - def __init__(self, params): - super(Block, self).__init__() - hidden_size = params["attention_hidden_size"] - self.attention_norm = LayerNorm(hidden_size, eps=1e-6) - self.ffn_norm = LayerNorm(hidden_size, eps=1e-6) - self.ffn = MLP(params) - self.attn = Attention(params) - - def forward(self, x): - # convert the tensor shape from [B, C, D, H, W] to [B, DHW, C] - [B, C, D, H, W] = list(x.shape) - new_shape = [B, C, D*H*W] - x = torch.reshape(x, new_shape) - x = torch.transpose(x, 1, 2) - - h = x - x = self.attention_norm(x) - x = self.attn(x) - x = x + h - - h = x - x = self.ffn_norm(x) - x = self.ffn(x) - x = x + h - - # convert the result back to [B, C, D, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, [B, C, D, H, W]) - return x - -class MedFormerV1(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param deep_supervise: (bool) Using deep supervision for training or not. - """ - def __init__(self, params): - super(MedFormerV1, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = Decoder(params) - self.attn = Block(params) - - def forward(self, x): - f = self.encoder(x) - f[-1] = self.attn(f[-1]) - output = self.decoder(f) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'deep_supervise': True, - 'attention_hidden_size': 128, - 'attention_num_heads': 4, - 'attention_mlp_dim': 256, - 'attention_dropout_rate': 0.2} - Net = MedFormerV1(params) - Net = Net.double() - - x = np.random.rand(1, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v2.py b/pymic/net/net3d/trans3d/MedFormer_v2.py deleted file mode 100644 index 00cb295..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_v2.py +++ /dev/null @@ -1,464 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import copy -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from pymic.net.net3d.unet3d import ConvBlock, Encoder, Decoder -from pymic.net.net3d.trans3d.MedFormer_v1 import Block -from timm.models.layers import DropPath, to_3tuple, trunc_normal_ - - -# code from nnFormer -class Mlp(nn.Module): - """ Multilayer perceptron.""" - - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - - B, S, H, W, C = x.shape - x = x.view(B, S // window_size, window_size, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, S, H, W): - - B = int(windows.shape[0] / (S * H * W / window_size / window_size / window_size)) - x = windows.view(B, S // window_size, H // window_size, W // window_size, window_size, window_size, window_size, -1) - x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, S, H, W, -1) - return x - - -class WindowAttention(nn.Module): - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), - num_heads)) - - # get pair-wise relative position index for each token inside the window - coords_s = torch.arange(self.window_size[0]) - coords_h = torch.arange(self.window_size[1]) - coords_w = torch.arange(self.window_size[2]) - coords = torch.stack(torch.meshgrid([coords_s, coords_h, coords_w])) - coords_flatten = torch.flatten(coords, 1) - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).contiguous() - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 2] += self.window_size[2] - 1 - - relative_coords[:, :, 0] *= 3 * self.window_size[1] - 1 - relative_coords[:, :, 1] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - trunc_normal_(self.relative_position_bias_table, std=.02) - - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None,pos_embed=None): - - B_, N, C = x.shape - - qkv = self.qkv(x) - - qkv=qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1).contiguous()) - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1] * self.window_size[2], - self.window_size[0] * self.window_size[1] * self.window_size[2], -1) - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() - if pos_embed is not None: - x = x+pos_embed - x = self.proj(x) - x = self.proj_drop(x) - return x - -class SwinTransformerBlock(nn.Module): - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - - self.attn = WindowAttention( - dim, window_size=to_3tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - - def forward(self, x, mask_matrix): - - B, L, C = x.shape - S, H, W = self.input_resolution - - assert L == S * H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, S, H, W, C) - - # pad feature maps to multiples of window size - pad_r = (self.window_size - W % self.window_size) % self.window_size - pad_b = (self.window_size - H % self.window_size) % self.window_size - pad_g = (self.window_size - S % self.window_size) % self.window_size - - x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b, 0, pad_g)) - _, Sp, Hp, Wp, _ = x.shape - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size,-self.shift_size), dims=(1, 2,3)) - attn_mask = mask_matrix - else: - shifted_x = x - attn_mask = None - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size * self.window_size, - C) - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=attn_mask,pos_embed=None) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, Sp, Hp, Wp) - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size, self.shift_size), dims=(1, 2, 3)) - else: - x = shifted_x - - if pad_r > 0 or pad_b > 0 or pad_g > 0: - x = x[:, :S, :H, :W, :].contiguous() - - x = x.view(B, S * H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - -class BasicLayer(nn.Module): - - def __init__(self, - dim, - input_resolution, - depth, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm, - downsample=True - ): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - self.depth = depth - # build blocks - - self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, S, H, W): - - - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - for blk in self.blocks: - - x = blk(x, attn_mask) - if self.downsample is not None: - x_down = self.downsample(x, S, H, W) - Ws, Wh, Ww = (S + 1) // 2, (H + 1) // 2, (W + 1) // 2 - return x, S, H, W, x_down, Ws, Wh, Ww - else: - return x, S, H, W, x, S, H, W - - -class AttUpBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True, with_att = False, att_params = None): - super(AttUpBlock, self).__init__() - self.trilinear = trilinear - self.with_att = with_att - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) - if(self.with_att): - input_resolution = att_params['input_resolution'] - depth = att_params['depth'] - num_heads = att_params['num_heads'] - self.attn = BasicLayer(out_channels, input_resolution, depth, num_heads, downsample=None) - - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - x = self.conv(x) - if(self.with_att): - [B, C, D, H, W] = list(x.shape) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.attn(x, D, H, W)[0] - x = x.view(-1, D, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - return x - -class AttDecoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(AttDecoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params['multiscale_pred'] - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - att_params = {"input_resolution": [24, 24, 24], "depth": 2, "num_heads": 4} - self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) - att_params = {"input_resolution": [48, 48, 48], "depth": 2, "num_heads": 4} - self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) - self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class MedFormerV2(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(MedFormerV2, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = AttDecoder(params) - self.attn = Block(params) - - def forward(self, x): - f = self.encoder(x) - f[-1] = self.attn(f[-1]) - output = self.decoder(f) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': True, - 'attention_hidden_size': 128, - 'attention_num_heads': 4, - 'attention_mlp_dim': 256, - 'attention_dropout_rate': 0.2} - - Net = MedFormerV2(params) - Net = Net.double() - - x = np.random.rand(1, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_v3.py b/pymic/net/net3d/trans3d/MedFormer_v3.py deleted file mode 100644 index f119a9c..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_v3.py +++ /dev/null @@ -1,255 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import numpy as np -import torch -import torch.nn as nn -from torch.nn.functional import interpolate -from pymic.net.net3d.unet3d import ConvBlock, Encoder -from pymic.net.net3d.trans3d.MedFormer_v1 import Block -from pymic.net.net3d.trans3d.MedFormer_v2 import SwinTransformerBlock, window_partition - -class GLAttLayer(nn.Module): - def __init__(self, - dim, - input_resolution, - num_heads, - window_size=7, - mlp_ratio=4., - qkv_bias=True, - qk_scale=None, - drop=0., - attn_drop=0., - drop_path=0., - norm_layer=nn.LayerNorm): - super().__init__() - self.window_size = window_size - self.shift_size = window_size // 2 - # build blocks - - self.lcl_att = SwinTransformerBlock( - dim=dim, - input_resolution=input_resolution, - num_heads=num_heads, - window_size=window_size, - shift_size=0, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop, - attn_drop=attn_drop, - drop_path=drop_path) - self.adpool = nn.AdaptiveAvgPool3d([12, 12, 12]) - - params = {'attention_hidden_size': dim, - 'attention_num_heads': 4, - 'attention_mlp_dim': dim, - 'attention_dropout_rate': 0.2} - self.glb_att = Block(params) - self.conv1x1 = nn.Sequential( - nn.Conv3d(2*dim, dim, kernel_size=1), - nn.BatchNorm3d(dim), - nn.LeakyReLU()) - - def forward(self, x): - [B, C, S, H, W] = list(x.shape) - # calculate attention mask for SW-MSA - Sp = int(np.ceil(S / self.window_size)) * self.window_size - Hp = int(np.ceil(H / self.window_size)) * self.window_size - Wp = int(np.ceil(W / self.window_size)) * self.window_size - img_mask = torch.zeros((1, Sp, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 - s_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for s in s_slices: - for h in h_slices: - for w in w_slices: - img_mask[:, s, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) - mask_windows = mask_windows.view(-1, - self.window_size * self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - # for local attention - xl = x.flatten(2).transpose(1, 2).contiguous() - xl = self.lcl_att(xl, attn_mask) - xl = xl.view(-1, S, H, W, C).permute(0, 4, 1, 2, 3).contiguous() - - # for global attention - xg = self.adpool(x) - xg = self.glb_att(xg) - xg = interpolate(xg, [S, H, W], mode = 'trilinear') - out = torch.cat([xl, xg], dim=1) - out = self.conv1x1(out) - return out - -class AttUpBlock(nn.Module): - """ - 3D upsampling followed by ConvBlock - - :param in_channels1: (int) Channel number of high-level features. - :param in_channels2: (int) Channel number of low-level features. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. - """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True, with_att = False, att_params = None): - super(AttUpBlock, self).__init__() - self.trilinear = trilinear - self.with_att = with_att - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) - else: - self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) - self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) - if(self.with_att): - input_resolution = att_params['input_resolution'] - num_heads = att_params['num_heads'] - window_size = att_params['window_size'] - self.attn = GLAttLayer(out_channels, input_resolution, num_heads, window_size, 2.0) - - def forward(self, x1, x2): - if self.trilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - x = self.conv(x) - if(self.with_att): - x = self.attn(x) - return x - - -class AttDecoder(nn.Module): - """ - Decoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(AttDecoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) - self.mul_pred = self.params['multiscale_pred'] - - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - if(len(self.ft_chns) == 5): - self.up1 = AttUpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - att_params = {"input_resolution": [24, 24, 24], "num_heads": 4, "window_size": 7} - self.up2 = AttUpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear, True, att_params) - att_params = {"input_resolution": [48, 48, 48], "num_heads": 4, "window_size": 7} - self.up3 = AttUpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear, True, att_params) - self.up4 = AttUpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) - - def forward(self, x): - if(len(self.ft_chns) == 5): - assert(len(x) == 5) - x0, x1, x2, x3, x4 = x - x_d3 = self.up1(x4, x3) - else: - assert(len(x) == 4) - x0, x1, x2, x3 = x - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - return output - -class MedFormerV3(nn.Module): - """ - An implementation of the U-Net. - - * Reference: Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: - 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. - `MICCAI (2) 2016: 424-432. `_ - - Note that there are some modifications from the original paper, such as - the use of batch normalization, dropout, leaky relu and deep supervision. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(MedFormerV3, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = AttDecoder(params) - params["attention_hidden_size"] = params['feature_chns'][-1] - params["attention_mlp_dim"] = params['feature_chns'][-1] - self.attn = Block(params) - - def forward(self, x): - f = self.encoder(x) - f[-1] = self.attn(f[-1]) - output = self.decoder(f) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': True, - 'attention_num_heads': 4, - 'attention_dropout_rate': 0.2} - - Net = MedFormerV3(params) - Net = Net.double() - - x = np.random.rand(2, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) diff --git a/pymic/net/net3d/trans3d/MedFormer_va1.py b/pymic/net/net3d/trans3d/MedFormer_va1.py deleted file mode 100644 index 27dfa3e..0000000 --- a/pymic/net/net3d/trans3d/MedFormer_va1.py +++ /dev/null @@ -1,105 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import torch -import torch.nn as nn -import numpy as np -from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm -from pymic.net.net3d.unet3d import Decoder - -class EmbeddingBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, padding, stride): - super(EmbeddingBlock, self).__init__() - self.out_channels = out_channels - self.conv1 = nn.Conv3d(in_channels, out_channels//2, kernel_size=kernel_size, padding=padding, stride = stride) - self.conv2 = nn.Conv3d(out_channels//2, out_channels, kernel_size=1) - self.act = nn.GELU() - self.norm1 = nn.LayerNorm(out_channels//2) - self.norm2 = nn.LayerNorm(out_channels) - - - def forward(self, x): - x = self.act(self.conv1(x)) - # norm 1 - Ws, Wh, Ww = x.size(2), x.size(3), x.size(4) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm1(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_channels // 2, Ws, Wh, Ww) - - x = self.act(self.conv2(x)) - x = x.flatten(2).transpose(1, 2).contiguous() - x = self.norm2(x) - x = x.transpose(1, 2).contiguous().view(-1, self.out_channels, Ws, Wh, Ww) - - return x - -class Encoder(nn.Module): - """ - Encoder of 3D UNet. - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 4 or 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - """ - def __init__(self, params): - super(Encoder, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - assert(len(self.ft_chns) == 4) - - self.down0 = EmbeddingBlock(self.in_chns, self.ft_chns[0], 3, 1, 1) - self.down1 = EmbeddingBlock(self.in_chns, self.ft_chns[1], 2, 0, 2) - self.down2 = EmbeddingBlock(self.in_chns, self.ft_chns[2], 4, 0, 4) - self.down3 = EmbeddingBlock(self.in_chns, self.ft_chns[3], 8, 0, 8) - - def forward(self, x): - x0 = self.down0(x) - x1 = self.down1(x) - x2 = self.down2(x) - x3 = self.down3(x) - output = [x0, x1, x2, x3] - return output - -class MedFormerVA1(nn.Module): - def __init__(self, params): - super(MedFormerVA1, self).__init__() - self.params = params - self.encoder = Encoder(params) - self.decoder = Decoder(params) - - def forward(self, x): - f = self.encoder(x) - output = self.decoder(f) - return output - - -if __name__ == "__main__": - params = {'in_chns':1, - 'class_num': 8, - 'feature_chns':[16, 32, 64, 128], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'deep_supervise': True, - 'attention_hidden_size': 128, - 'attention_num_heads': 4, - 'attention_mlp_dim': 256, - 'attention_dropout_rate': 0.2} - Net = MedFormerVA1(params) - Net = Net.double() - - x = np.random.rand(1, 1, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print("output length", len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/__init__.py b/pymic/net/net3d/trans3d/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pymic/net/net3d/trans3d/nnFormer_wrap.py b/pymic/net/net3d/trans3d/nnFormer_wrap.py deleted file mode 100644 index 35593a4..0000000 --- a/pymic/net/net3d/trans3d/nnFormer_wrap.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import math -import torch -import torch.nn as nn -import numpy as np -from nnformer.network_architecture.nnFormer_tumor import nnFormer - -class nnFormer_wrap(nn.Module): - def __init__(self, params): - super(nnFormer_wrap, self).__init__() - patch_size = params["patch_size"] # 96x96x96 - n_class = params['class_num'] - in_chns = params['in_chns'] - # https://github.com/282857341/nnFormer/blob/main/nnformer/network_architecture/nnFormer_tumor.py - self.nnformer = nnFormer(crop_size = patch_size, - embedding_dim=192, - input_channels = in_chns, - num_classes = n_class, - conv_op=nn.Conv3d, - depths =[2,2,2,2], - num_heads = [6, 12, 24, 48], - patch_size = [4,4,4], - window_size= [4,4,8,4], - deep_supervision=False) - - def forward(self, x): - return self.nnformer(x) - -if __name__ == "__main__": - params = {"patch_size": [96, 96, 96], - "in_chns": 1, - "class_num": 5} - Net = nnFormer_wrap(params) - Net = Net.double() - - x = np.random.rand(1, 1, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(y.shape) diff --git a/pymic/net/net3d/trans3d/unetr.py b/pymic/net/net3d/trans3d/unetr.py deleted file mode 100644 index ea90b2f..0000000 --- a/pymic/net/net3d/trans3d/unetr.py +++ /dev/null @@ -1,227 +0,0 @@ -from __future__ import print_function, division - -import torch -import torch.nn as nn - -from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock -from monai.networks.blocks.dynunet_block import UnetOutBlock -from monai.networks.nets import ViT - - -class UNETR(nn.Module): - """ - UNETR based on: "Hatamizadeh et al., - UNETR: Transformers for 3D Medical Image Segmentation " - """ - - def __init__(self, params): - # in_channels: int, - # out_channels: int, - # img_size: Tuple[int, int, int], - # feature_size: int = 16, - # hidden_size: int = 768, - # mlp_dim: int = 3072, - # num_heads: int = 12, - # pos_embed: str = "perceptron", - # norm_name: Union[Tuple, str] = "instance", - # conv_block: bool = False, - # res_block: bool = True, - # dropout_rate: float = 0.0, - # ) -> None: - """ - Args: - in_channels: dimension of input channels. - out_channels: dimension of output channels. - img_size: dimension of input image. - feature_size: dimension of network feature size. - hidden_size: dimension of hidden layer. - mlp_dim: dimension of feedforward layer. - num_heads: number of attention heads. - pos_embed: position embedding layer type. - norm_name: feature normalization type and arguments. - conv_block: bool argument to determine if convolutional block is used. - res_block: bool argument to determine if residual block is used. - dropout_rate: faction of the input units to drop. - Examples:: - # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm - >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') - # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm - >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') - """ - - super().__init__() - in_channels = params['in_chns'] - out_channels = params['class_num'] - img_size = params['img_size'] - feature_size = 16 - hidden_size = 768 - mlp_dim = 3072 - num_heads = 12 - pos_embed = "perceptron" - norm_name = "instance" - conv_block = False - res_block = True - dropout_rate = 0.0 - - if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - raise AssertionError("hidden size should be divisible by num_heads.") - - if pos_embed not in ["conv", "perceptron"]: - raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - - self.num_layers = 12 - self.patch_size = (16, 16, 16) - self.feat_size = ( - img_size[0] // self.patch_size[0], - img_size[1] // self.patch_size[1], - img_size[2] // self.patch_size[2], - ) - self.hidden_size = hidden_size - self.classification = False - self.vit = ViT( - in_channels=in_channels, - img_size=img_size, - patch_size=self.patch_size, - hidden_size=hidden_size, - mlp_dim=mlp_dim, - num_layers=self.num_layers, - num_heads=num_heads, - pos_embed=pos_embed, - classification=self.classification, - dropout_rate=dropout_rate, - ) - self.encoder1 = UnetrBasicBlock( - spatial_dims=3, - in_channels=in_channels, - out_channels=feature_size, - kernel_size=3, - stride=1, - norm_name=norm_name, - res_block=res_block, - ) - self.encoder2 = UnetrPrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 2, - num_layer=2, - kernel_size=3, - stride=1, - upsample_kernel_size=2, - norm_name=norm_name, - conv_block=conv_block, - res_block=res_block, - ) - self.encoder3 = UnetrPrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 4, - num_layer=1, - kernel_size=3, - stride=1, - upsample_kernel_size=2, - norm_name=norm_name, - conv_block=conv_block, - res_block=res_block, - ) - self.encoder4 = UnetrPrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 8, - num_layer=0, - kernel_size=3, - stride=1, - upsample_kernel_size=2, - norm_name=norm_name, - conv_block=conv_block, - res_block=res_block, - ) - self.decoder5 = UnetrUpBlock( - spatial_dims=3, - in_channels=hidden_size, - out_channels=feature_size * 8, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.decoder4 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 8, - out_channels=feature_size * 4, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.decoder3 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 4, - out_channels=feature_size * 2, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.decoder2 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 2, - out_channels=feature_size, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - res_block=res_block, - ) - self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore - - def proj_feat(self, x, hidden_size, feat_size): - x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) - x = x.permute(0, 4, 1, 2, 3).contiguous() - return x - - def load_from(self, weights): - with torch.no_grad(): - res_weight = weights - # copy weights from patch embedding - for i in weights["state_dict"]: - print(i) - self.vit.patch_embedding.position_embeddings.copy_( - weights["state_dict"]["module.transformer.patch_embedding.position_embeddings_3d"] - ) - self.vit.patch_embedding.cls_token.copy_( - weights["state_dict"]["module.transformer.patch_embedding.cls_token"] - ) - self.vit.patch_embedding.patch_embeddings[1].weight.copy_( - weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.weight"] - ) - self.vit.patch_embedding.patch_embeddings[1].bias.copy_( - weights["state_dict"]["module.transformer.patch_embedding.patch_embeddings.1.bias"] - ) - - # copy weights from encoding blocks (default: num of blocks: 12) - for bname, block in self.vit.blocks.named_children(): - print(block) - block.loadFrom(weights, n_block=bname) - # last norm layer of transformer - self.vit.norm.weight.copy_(weights["state_dict"]["module.transformer.norm.weight"]) - self.vit.norm.bias.copy_(weights["state_dict"]["module.transformer.norm.bias"]) - - def forward(self, x_in): - x, hidden_states_out = self.vit(x_in) - enc1 = self.encoder1(x_in) - x2 = hidden_states_out[3] - enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) - x3 = hidden_states_out[6] - enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) - x4 = hidden_states_out[9] - enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) - dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) - dec3 = self.decoder5(dec4, enc4) - dec2 = self.decoder4(dec3, enc3) - dec1 = self.decoder3(dec2, enc2) - out = self.decoder2(dec1, enc1) - logits = self.out(out) - return logits - diff --git a/pymic/net/net3d/trans3d/unetr_pp.py b/pymic/net/net3d/trans3d/unetr_pp.py deleted file mode 100644 index a4ab7e6..0000000 --- a/pymic/net/net3d/trans3d/unetr_pp.py +++ /dev/null @@ -1,469 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Optional, Sequence, Tuple, Union -from pymic.net.net3d.trans3d.unetr_pp_block import UnetOutBlock, UnetResBlock, get_conv_layer -from timm.models.layers import trunc_normal_ -from monai.utils import optional_import -from monai.networks.blocks.convolutions import Convolution -from monai.networks.layers.factories import Act, Norm -from monai.networks.layers.utils import get_act_layer, get_norm_layer - -einops, _ = optional_import("einops") - -class LayerNorm(nn.Module): - def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) - self.bias = nn.Parameter(torch.zeros(normalized_shape)) - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape,) - - def forward(self, x): - if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None] * x + self.bias[:, None, None] - return x - -class EPA(nn.Module): - """ - Efficient Paired Attention Block, based on: "Shaker et al., - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" - """ - def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False, - channel_attn_drop=0.1, spatial_attn_drop=0.1): - super().__init__() - self.num_heads = num_heads - self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) - self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) - - # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel) - self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias) - - # E and F are projection matrices with shared weights used in spatial attention module to project - # keys and values from HWD-dimension to P-dimension - self.E = self.F = nn.Linear(input_size, proj_size) - - self.attn_drop = nn.Dropout(channel_attn_drop) - self.attn_drop_2 = nn.Dropout(spatial_attn_drop) - - self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2)) - self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2)) - - def forward(self, x): - B, N, C = x.shape - - qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads) - - qkvv = qkvv.permute(2, 0, 3, 1, 4) - - q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3] - - q_shared = q_shared.transpose(-2, -1) - k_shared = k_shared.transpose(-2, -1) - v_CA = v_CA.transpose(-2, -1) - v_SA = v_SA.transpose(-2, -1) - - k_shared_projected = self.E(k_shared) - - v_SA_projected = self.F(v_SA) - - q_shared = torch.nn.functional.normalize(q_shared, dim=-1) - k_shared = torch.nn.functional.normalize(k_shared, dim=-1) - - attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature - - attn_CA = attn_CA.softmax(dim=-1) - attn_CA = self.attn_drop(attn_CA) - - x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C) - - attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2 - - attn_SA = attn_SA.softmax(dim=-1) - attn_SA = self.attn_drop_2(attn_SA) - - x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C) - - # Concat fusion - x_SA = self.out_proj(x_SA) - x_CA = self.out_proj2(x_CA) - x = torch.cat((x_SA, x_CA), dim=-1) - return x - - @torch.jit.ignore - def no_weight_decay(self): - return {'temperature', 'temperature2'} - - -class TransformerBlock(nn.Module): - """ - A transformer block, based on: "Shaker et al., - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" - """ - - def __init__( - self, - input_size: int, - hidden_size: int, - proj_size: int, - num_heads: int, - dropout_rate: float = 0.0, - pos_embed=False, - ) -> None: - """ - Args: - input_size: the size of the input for each stage. - hidden_size: dimension of hidden layer. - proj_size: projection size for keys and values in the spatial attention module. - num_heads: number of attention heads. - dropout_rate: faction of the input units to drop. - pos_embed: bool argument to determine if positional embedding is used. - - """ - - super().__init__() - - if not (0 <= dropout_rate <= 1): - raise ValueError("dropout_rate should be between 0 and 1.") - - if hidden_size % num_heads != 0: - print("Hidden size is ", hidden_size) - print("Num heads is ", num_heads) - raise ValueError("hidden_size should be divisible by num_heads.") - - self.norm = nn.LayerNorm(hidden_size) - self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True) - self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads, channel_attn_drop=dropout_rate,spatial_attn_drop=dropout_rate) - self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch") - self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1)) - - self.pos_embed = None - if pos_embed: - self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size)) - - def forward(self, x): - B, C, H, W, D = x.shape - x = x.reshape(B, C, H * W * D).permute(0, 2, 1) - - if self.pos_embed is not None: - x = x + self.pos_embed - attn = x + self.gamma * self.epa_block(self.norm(x)) - - attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3) # (B, C, H, W, D) - attn = self.conv51(attn_skip) - x = attn_skip + self.conv8(attn) - - return x - -class UnetrPPEncoder(nn.Module): - def __init__(self, input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4],dims=[32, 64, 128, 256], - proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, - in_channels=1, dropout=0.0, transformer_dropout_rate=0.15, kernel_size=(2,4,4), **kwargs): - super().__init__() - - self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers - stem_layer = nn.Sequential( - get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=kernel_size, stride=kernel_size, - dropout=dropout, conv_only=True, ), - get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]), - ) - self.downsample_layers.append(stem_layer) - for i in range(3): - downsample_layer = nn.Sequential( - get_conv_layer(spatial_dims, dims[i], dims[i + 1], kernel_size=(2, 2, 2), stride=(2, 2, 2), - dropout=dropout, conv_only=True, ), - get_norm_layer(name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]), - ) - self.downsample_layers.append(downsample_layer) - - self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple Transformer blocks - for i in range(4): - stage_blocks = [] - for j in range(depths[i]): - stage_blocks.append(TransformerBlock(input_size=input_size[i], hidden_size=dims[i], proj_size=proj_size[i], num_heads=num_heads, - dropout_rate=transformer_dropout_rate, pos_embed=True)) - self.stages.append(nn.Sequential(*stage_blocks)) - self.hidden_states = [] - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, (LayerNorm, nn.LayerNorm)): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward_features(self, x): - hidden_states = [] - x = self.downsample_layers[0](x) - x = self.stages[0](x) - - hidden_states.append(x) - - for i in range(1, 4): - x = self.downsample_layers[i](x) - x = self.stages[i](x) - if i == 3: # Reshape the output of the last stage - x = einops.rearrange(x, "b c h w d -> b (h w d) c") - hidden_states.append(x) - return x, hidden_states - - def forward(self, x): - x, hidden_states = self.forward_features(x) - return x, hidden_states - - -class UnetrUpBlock(nn.Module): - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - proj_size: int = 64, - num_heads: int = 4, - out_size: int = 0, - depth: int = 3, - conv_decoder: bool = False, - ) -> None: - """ - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - upsample_kernel_size: convolution kernel size for transposed convolution layers. - norm_name: feature normalization type and arguments. - proj_size: projection size for keys and values in the spatial attention module. - num_heads: number of heads inside each EPA module. - out_size: spatial size for each decoder. - depth: number of blocks for the current decoder stage. - """ - - super().__init__() - upsample_stride = upsample_kernel_size - self.transp_conv = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_stride, - conv_only=True, - is_transposed=True, - ) - - # 4 feature resolution stages, each consisting of multiple residual blocks - self.decoder_block = nn.ModuleList() - - # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block (see suppl. material in the paper) - if conv_decoder == True: - self.decoder_block.append( - UnetResBlock(spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, - norm_name=norm_name, )) - else: - stage_blocks = [] - for j in range(depth): - stage_blocks.append(TransformerBlock(input_size=out_size, hidden_size= out_channels, proj_size=proj_size, num_heads=num_heads, - dropout_rate=0.15, pos_embed=True)) - self.decoder_block.append(nn.Sequential(*stage_blocks)) - - def _init_weights(self, m): - if isinstance(m, (nn.Conv2d, nn.Linear)): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, (nn.LayerNorm)): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def forward(self, inp, skip): - - out = self.transp_conv(inp) - out = out + skip - out = self.decoder_block[0](out) - - return out - - -class UNETR_PP(nn.Module): - """ - UNETR++ based on: "Shaker et al., - UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" - """ - - def __init__(self, params): - """ - Args: - in_channels: dimension of input channels. - out_channels: dimension of output channels. - img_size: dimension of input image. - feature_size: dimension of network feature size. - hidden_size: dimension of the last encoder. - num_heads: number of attention heads. - pos_embed: position embedding layer type. - norm_name: feature normalization type and arguments. - dropout_rate: faction of the input units to drop. - depths: number of blocks for each stage. - dims: number of channel maps for the stages. - conv_op: type of convolution operation. - do_ds: use deep supervision to compute the loss. - - """ - super().__init__() - in_channels = params['in_chns'] - out_channels = params['class_num'] - img_size = params['img_size'] - self.res_mode= params.get("resolution_mode", 1) - feature_size = params.get('feature_size', 16) - hidden_size = params.get('hidden_size', 256) - num_heads = params.get('num_heads', 4) - pos_embed = params.get('pos_embed', "perceptron") - norm_name = params.get('norm_name', "instance") - dropout_rate = params.get('dropout_rate', 0.0) - depths = params.get('depths', [3, 3, 3, 3]) - dims = params.get('dims', [32, 64, 128, 256]) - conv_op = nn.Conv3d - do_ds = params.get('deep_supervise', True) - - self.do_ds = do_ds - self.conv_op = conv_op - self.num_classes = out_channels - if not (0 <= dropout_rate <= 1): - raise AssertionError("dropout_rate should be between 0 and 1.") - - if pos_embed not in ["conv", "perceptron"]: - raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - - kernel_ds = [4, 2, 1] - kernel_d = kernel_ds[self.res_mode] - self.patch_size = (kernel_d, 4, 4) - - self.feat_size = ( - img_size[0] // self.patch_size[0] // 8, # 8 is the downsampling happened through the four encoders stages - img_size[1] // self.patch_size[1] // 8, # 8 is the downsampling happened through the four encoders stages - img_size[2] // self.patch_size[2] // 8, # 8 is the downsampling happened through the four encoders stages - ) - - self.hidden_size = hidden_size - - self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads, - in_channels=in_channels, kernel_size=self.patch_size) - - self.encoder1 = UnetResBlock( - spatial_dims=3, - in_channels=in_channels, - out_channels=feature_size, - kernel_size=3, - stride=1, - norm_name=norm_name, - ) - self.decoder5 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 16, - out_channels=feature_size * 8, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - out_size=8 * 8 * 8, - ) - self.decoder4 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 8, - out_channels=feature_size * 4, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - out_size=16 * 16 * 16, - ) - self.decoder3 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 4, - out_channels=feature_size * 2, - kernel_size=3, - upsample_kernel_size=2, - norm_name=norm_name, - out_size=32 * 32 * 32, - ) - - self.decoder2 = UnetrUpBlock( - spatial_dims=3, - in_channels=feature_size * 2, - out_channels=feature_size, - kernel_size=3, - upsample_kernel_size= self.patch_size, - norm_name=norm_name, - out_size= kernel_d*32 * 128 * 128, - conv_decoder=True, - ) - self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) - # if self.do_ds: - self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) - self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) - - def proj_feat(self, x, hidden_size, feat_size): - x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) - x = x.permute(0, 4, 1, 2, 3).contiguous() - return x - - def forward(self, x_in): - x_output, hidden_states = self.unetr_pp_encoder(x_in) - - convBlock = self.encoder1(x_in) - - # Four encoders - enc1 = hidden_states[0] - enc2 = hidden_states[1] - enc3 = hidden_states[2] - enc4 = hidden_states[3] - - # Four decoders - dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size) - dec3 = self.decoder5(dec4, enc3) - dec2 = self.decoder4(dec3, enc2) - dec1 = self.decoder3(dec2, enc1) - - out = self.decoder2(dec1, convBlock) - if self.do_ds: - logits = [self.out1(out), self.out2(dec1), self.out3(dec2)] - else: - logits = self.out1(out) - - return logits - - -if __name__ == "__main__": - depths = [128, 64, 32] - for i in range(3): - params = {'in_chns': 4, - 'class_num': 2, - 'img_size': [depths[i], 128, 128], - 'resolution_mode': i - } - net = UNETR_PP(params) - net.double() - - x = np.random.rand(2, 4, depths[i], 128, 128) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = net(xt) - print(len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net3d/trans3d/unetr_pp_block.py b/pymic/net/net3d/trans3d/unetr_pp_block.py deleted file mode 100644 index 89a8769..0000000 --- a/pymic/net/net3d/trans3d/unetr_pp_block.py +++ /dev/null @@ -1,278 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import numpy as np -import torch -import torch.nn as nn -from typing import Optional, Sequence, Tuple, Union -from monai.networks.blocks.convolutions import Convolution -from monai.networks.layers.factories import Act, Norm -from monai.networks.layers.utils import get_act_layer, get_norm_layer - - -class UnetResBlock(nn.Module): - """ - A skip-connection based module that can be used for DynUNet, based on: - `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. - `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - stride: convolution stride. - norm_name: feature normalization type and arguments. - act_name: activation layer type and arguments. - dropout: dropout probability. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, - ): - super().__init__() - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - dropout=dropout, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True - ) - self.lrelu = get_act_layer(name=act_name) - self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - self.downsample = in_channels != out_channels - stride_np = np.atleast_1d(stride) - if not np.all(stride_np == 1): - self.downsample = True - if self.downsample: - self.conv3 = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=stride, dropout=dropout, conv_only=True - ) - self.norm3 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - - def forward(self, inp): - residual = inp - out = self.conv1(inp) - out = self.norm1(out) - out = self.lrelu(out) - out = self.conv2(out) - out = self.norm2(out) - if hasattr(self, "conv3"): - residual = self.conv3(residual) - if hasattr(self, "norm3"): - residual = self.norm3(residual) - out += residual - out = self.lrelu(out) - return out - - -class UnetBasicBlock(nn.Module): - """ - A CNN module module that can be used for DynUNet, based on: - `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. - `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - stride: convolution stride. - norm_name: feature normalization type and arguments. - act_name: activation layer type and arguments. - dropout: dropout probability. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, - ): - super().__init__() - self.conv1 = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - dropout=dropout, - conv_only=True, - ) - self.conv2 = get_conv_layer( - spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True - ) - self.lrelu = get_act_layer(name=act_name) - self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) - - def forward(self, inp): - out = self.conv1(inp) - out = self.norm1(out) - out = self.lrelu(out) - out = self.conv2(out) - out = self.norm2(out) - out = self.lrelu(out) - return out - - -class UnetUpBlock(nn.Module): - """ - An upsampling module that can be used for DynUNet, based on: - `Automated Design of Deep Learning Methods for Biomedical Image Segmentation `_. - `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation `_. - - Args: - spatial_dims: number of spatial dimensions. - in_channels: number of input channels. - out_channels: number of output channels. - kernel_size: convolution kernel size. - stride: convolution stride. - upsample_kernel_size: convolution kernel size for transposed convolution layers. - norm_name: feature normalization type and arguments. - act_name: activation layer type and arguments. - dropout: dropout probability. - trans_bias: transposed convolution bias. - - """ - - def __init__( - self, - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int], - stride: Union[Sequence[int], int], - upsample_kernel_size: Union[Sequence[int], int], - norm_name: Union[Tuple, str], - act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), - dropout: Optional[Union[Tuple, str, float]] = None, - trans_bias: bool = False, - ): - super().__init__() - upsample_stride = upsample_kernel_size - self.transp_conv = get_conv_layer( - spatial_dims, - in_channels, - out_channels, - kernel_size=upsample_kernel_size, - stride=upsample_stride, - dropout=dropout, - bias=trans_bias, - conv_only=True, - is_transposed=True, - ) - self.conv_block = UnetBasicBlock( - spatial_dims, - out_channels + out_channels, - out_channels, - kernel_size=kernel_size, - stride=1, - dropout=dropout, - norm_name=norm_name, - act_name=act_name, - ) - - def forward(self, inp, skip): - # number of channels for skip should equals to out_channels - out = self.transp_conv(inp) - out = torch.cat((out, skip), dim=1) - out = self.conv_block(out) - return out - - -class UnetOutBlock(nn.Module): - def __init__( - self, spatial_dims: int, in_channels: int, out_channels: int, dropout: Optional[Union[Tuple, str, float]] = None - ): - super().__init__() - self.conv = get_conv_layer( - spatial_dims, in_channels, out_channels, kernel_size=1, stride=1, dropout=dropout, bias=True, conv_only=True - ) - - def forward(self, inp): - return self.conv(inp) - - -def get_conv_layer( - spatial_dims: int, - in_channels: int, - out_channels: int, - kernel_size: Union[Sequence[int], int] = 3, - stride: Union[Sequence[int], int] = 1, - act: Optional[Union[Tuple, str]] = Act.PRELU, - norm: Union[Tuple, str] = Norm.INSTANCE, - dropout: Optional[Union[Tuple, str, float]] = None, - bias: bool = False, - conv_only: bool = True, - is_transposed: bool = False, -): - padding = get_padding(kernel_size, stride) - output_padding = None - if is_transposed: - output_padding = get_output_padding(kernel_size, stride, padding) - return Convolution( - spatial_dims, - in_channels, - out_channels, - strides=stride, - kernel_size=kernel_size, - act=act, - norm=norm, - dropout=dropout, - bias=bias, - conv_only=conv_only, - is_transposed=is_transposed, - padding=padding, - output_padding=output_padding, - ) - - -def get_padding( - kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] -) -> Union[Tuple[int, ...], int]: - - kernel_size_np = np.atleast_1d(kernel_size) - stride_np = np.atleast_1d(stride) - padding_np = (kernel_size_np - stride_np + 1) / 2 - if np.min(padding_np) < 0: - raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.") - padding = tuple(int(p) for p in padding_np) - - return padding if len(padding) > 1 else padding[0] - - -def get_output_padding( - kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int] -) -> Union[Tuple[int, ...], int]: - kernel_size_np = np.atleast_1d(kernel_size) - stride_np = np.atleast_1d(stride) - padding_np = np.atleast_1d(padding) - - out_padding_np = 2 * padding_np + stride_np - kernel_size_np - if np.min(out_padding_np) < 0: - raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.") - out_padding = tuple(int(p) for p in out_padding_np) - - return out_padding if len(out_padding) > 1 else out_padding[0] diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index a17bcb8..e383e77 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import torch import torch.nn as nn import numpy as np -from torch.nn.functional import interpolate + class ConvBlock(nn.Module): """ @@ -56,22 +57,32 @@ class UpBlock(nn.Module): :param in_channels2: (int) Channel number of low-level features. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param trilinear: (bool) Use trilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - trilinear=True): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode=2): super(UpBlock, self).__init__() - self.trilinear = trilinear - if trilinear: - self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "trilinear"] + if(up_mode > 2): + raise ValueError("The upsample mode should be 0-2, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose3d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv3d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.trilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) @@ -129,9 +140,10 @@ class Decoder(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - :param multiscale_pred: (bool) Get multi-scale prediction. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(Decoder, self).__init__() @@ -140,21 +152,25 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.trilinear = self.params.get('trilinear', True) + self.up_mode = self.params.get('up_mode', 2) self.mul_pred = self.params.get('multiscale_pred', False) - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred): self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage def forward(self, x): if(len(self.ft_chns) == 5): @@ -169,7 +185,7 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.mul_pred): + if(self.mul_pred and self.stage == 'train'): output1 = self.out_conv1(x_d1) output2 = self.out_conv2(x_d2) output3 = self.out_conv3(x_d3) @@ -196,77 +212,38 @@ class UNet3D(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param trilinear: (bool) Using trilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Trilinear`). The default value + is 2 (`Trilinear`). :param multiscale_pred: (bool) Get multi-scale prediction. """ def __init__(self, params): super(UNet3D, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.trilinear = self.params['trilinear'] - self.mul_pred = self.params['multiscale_pred'] - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - dropout_p = self.dropout[3], trilinear=self.trilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - dropout_p = self.dropout[2], trilinear=self.trilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - dropout_p = self.dropout[1], trilinear=self.trilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - dropout_p = self.dropout[0], trilinear=self.trilinear) - - self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1) + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.stage = 'train' + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) def forward(self, x): - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] + f = self.encoder(x) + output = self.decoder(f) return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'class_num': 2, - 'feature_chns':[2, 8, 32, 64], - 'dropout' : [0, 0, 0, 0.5], - 'trilinear': True, - 'multiscale_pred': False} - Net = UNet3D(params) - Net = Net.double() - - x = np.random.rand(4, 4, 96, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - y = y.detach().numpy() - print(y.shape) diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py index c681de9..db8b10b 100644 --- a/pymic/net_run/agent_preprocess.py +++ b/pymic/net_run/agent_preprocess.py @@ -8,8 +8,8 @@ from pymic.io.image_read_write import save_nd_array_as_image from pymic.io.nifty_dataset import NiftyDataset from pymic.transform.trans_dict import TransformDict - - +from pymic.net_run.agent_abstract import seed_torch +from pymic.net_run.self_sup.util import volume_fusion class PreprocessAgent(object): def __init__(self, config): @@ -19,9 +19,14 @@ def __init__(self, config): self.task_type = config['dataset']['task_type'] self.dataloader = None self.dataloader_unlab= None + + deterministic = config['dataset'].get('deterministic', True) + if(deterministic): + random_seed = config['dataset'].get('random_seed', 1) + seed_torch(random_seed) def get_dataset_from_config(self): - root_dir = self.config['dataset']['root_dir'] + root_dir = self.config['dataset']['data_dir'] modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']["transform"] @@ -40,6 +45,8 @@ def get_dataset_from_config(self): data_csv = self.config['dataset'].get('data_csv', None) data_csv_unlab = self.config['dataset'].get('data_csv_unlab', None) + batch_size = self.config['dataset'].get('batch_size', 1) + data_shuffle = self.config['dataset'].get('data_shuffle', False) if(data_csv is not None): dataset = NiftyDataset(root_dir = root_dir, csv_file = data_csv, @@ -48,7 +55,7 @@ def get_dataset_from_config(self): transform = data_transform, task = self.task_type) self.dataloader = torch.utils.data.DataLoader(dataset, - batch_size = 1, shuffle=False, num_workers= 8, + batch_size = batch_size, shuffle=data_shuffle, num_workers= 8, worker_init_fn=None, generator = torch.Generator()) if(data_csv_unlab is not None): dataset_unlab = NiftyDataset(root_dir = root_dir, @@ -58,7 +65,7 @@ def get_dataset_from_config(self): transform = data_transform, task = self.task_type) self.dataloader_unlab = torch.utils.data.DataLoader(dataset_unlab, - batch_size = 1, shuffle=False, num_workers= 8, + batch_size = batch_size, shuffle=data_shuffle, num_workers= 8, worker_init_fn=None, generator = torch.Generator()) def run(self): @@ -66,39 +73,38 @@ def run(self): Do preprocessing for labeled and unlabeled data. """ self.get_dataset_from_config() - out_dir = self.config['dataset']['output_dir'] + out_dir = self.config['dataset']['output_dir'] + modal_num = self.config['dataset']['modal_num'] + if(not os.path.isdir(out_dir)): + os.mkdir(out_dir) + batch_operation = self.config['dataset'].get('batch_operation', None) for dataloader in [self.dataloader, self.dataloader_unlab]: - for item in dataloader: - img = item['image'][0] # the batch size is 1 - # save differnt modaliteis - img_names = item['names'] - spacing = [x.numpy()[0] for x in item['spacing']] - for i in range(img.shape[0]): - image_name = out_dir + "/" + img_names[i][0] - print(image_name) - save_nd_array_as_image(img[i], image_name, reference_name = None, spacing=spacing) - if('label' in item): - lab = item['label'][0] - label_name = out_dir + "/" + img_names[-1][0] - print(label_name) - save_nd_array_as_image(lab[0], label_name, reference_name = None, spacing=spacing) - -def main(): - """ - The main function for data preprocessing. - """ - if(len(sys.argv) < 2): - print('Number of arguments should be 2. e.g.') - print(' pymic_preprocess config.cfg') - exit() - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) - agent = PreprocessAgent(config) - agent.run() + if(dataloader is None): + continue + for data in dataloader: + inputs = data['image'] + labels = data.get('label', None) + img_names = data['names'] + if(len(img_names) == modal_num): # for unlabeled dataset + lab_names = [item.replace(".nii.gz", "_lab.nii.gz") for item in img_names[0]] + else: + lab_names = img_names[-1] + B, C = inputs.shape[0], inputs.shape[1] + spacing = [x.numpy()[0] for x in data['spacing']] + + if(batch_operation is not None and 'VolumeFusion' in batch_operation): + class_num = self.config['dataset']['VolumeFusion_cls_num'.lower()] + block_range = self.config['dataset']['VolumeFusion_block_range'.lower()] + size_min = self.config['dataset']['VolumeFusion_size_min'.lower()] + size_max = self.config['dataset']['VolumeFusion_size_max'.lower()] + inputs, labels = volume_fusion(inputs, class_num - 1, block_range, size_min, size_max) -if __name__ == "__main__": - main() - + for b in range(B): + for c in range(C): + image_name = out_dir + "/" + img_names[c][b] + print(image_name) + save_nd_array_as_image(inputs[b][c], image_name, reference_name = None, spacing=spacing) + if(labels is not None): + label_name = out_dir + "/" + lab_names[b] + print(label_name) + save_nd_array_as_image(labels[b][0], label_name, reference_name = None, spacing=spacing) diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 2d6d489..2d61d0d 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -177,6 +177,7 @@ def training(self): inputs, labels_prob = mixup(inputs, labels_prob) # for debug + # print("current iteration", it) # if(it > 10): # break # for i in range(inputs.shape[0]): diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index b0190ad..e0e466e 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -104,7 +104,7 @@ def __infer_with_sliding_window(self, image): weight = torch.zeros(output_shape).to(image.device) temp_w = self.__get_gaussian_weight_map(window_size) temp_w = np.broadcast_to(temp_w, [batch_size, class_num] + window_size) - temp_w = torch.from_numpy(temp_w).to(image.device) + temp_w = torch.from_numpy(np.array(temp_w)).to(image.device) temp_in_shape = img_full_shape[:2] + window_size tempx = torch.ones(temp_in_shape).to(image.device) out_num, scale_list = self.__get_prediction_number_and_scales(tempx) diff --git a/pymic/net_run/noisy_label/nll_clslsr.py b/pymic/net_run/noisy_label/nll_clslsr.py index 0148621..c977eba 100644 --- a/pymic/net_run/noisy_label/nll_clslsr.py +++ b/pymic/net_run/noisy_label/nll_clslsr.py @@ -142,9 +142,9 @@ def test_time_dropout(m): print(gt.shape, pred_cat.shape) conf = get_confident_map(gt, pred_cat) conf = conf.reshape(-1, 256, 256).astype(np.uint8) * 255 - save_dir = self.config['dataset']['root_dir'] + "/slsr_conf" + save_dir = self.config['dataset']['train_dir'] + "/slsr_conf" for idx in range(len(filename_list)): - filename = filename_list[idx][0].split('/')[-1] + filename = filename_list[idx][0][0].split('/')[-1] conf_map = Image.fromarray(conf[idx]) dst_path = os.path.join(save_dir, filename) conf_map.save(dst_path) @@ -152,32 +152,29 @@ def test_time_dropout(m): def get_confidence_map(cfg_file): config = parse_config(cfg_file) config = synchronize_config(config) + agent = NLLCLSLSR(config, 'test') - # set dataset - transform_names = config['dataset']['valid_transform'] + # set customized dataset for testing, i.e,. inference with training images + trans_names, trans_params = agent.get_transform_names_and_parameters('valid') transform_list = [] - transform_dict = TransformDict - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: - if(name not in transform_dict): + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: + if(name not in agent.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = transform_dict[name](transform_param) + one_transform = agent.transform_dict[name](trans_params) transform_list.append(one_transform) - data_transform = transforms.Compose(transform_list) + data_transform = transforms.Compose(transform_list) csv_file = config['dataset']['train_csv'] modal_num = config['dataset'].get('modal_num', 1) - dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], + stage_dir = config['dataset']['train_dir'] + dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, with_label= True, - transform = data_transform ) + transform = data_transform, + task = agent.task_type) - agent = NLLCLSLSR(config, 'test') agent.set_datasets(None, None, dataset) agent.transform_list = transform_list agent.create_dataset() diff --git a/pymic/net_run/noisy_label/nll_co_teaching.py b/pymic/net_run/noisy_label/nll_co_teaching.py index ec8e230..c60616e 100644 --- a/pymic/net_run/noisy_label/nll_co_teaching.py +++ b/pymic/net_run/noisy_label/nll_co_teaching.py @@ -18,22 +18,6 @@ from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio -class BiNet(nn.Module): - def __init__(self, params): - super(BiNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - - if(self.training): - return out1, out2 - else: - return (out1 + out2) / 2 - class NLLCoTeaching(SegmentationAgent): """ Co-teaching for noisy-label learning. @@ -58,14 +42,6 @@ def __init__(self, config, stage = 'train'): logging.warn("only CrossEntropyLoss supported for" + " coteaching, the specified loss {0:} is ingored".format(loss_type)) - def create_network(self): - if(self.net is None): - self.net = BiNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] diff --git a/pymic/net_run/noisy_label/nll_dast.py b/pymic/net_run/noisy_label/nll_dast.py index 1921e9c..a90747c 100644 --- a/pymic/net_run/noisy_label/nll_dast.py +++ b/pymic/net_run/noisy_label/nll_dast.py @@ -117,31 +117,27 @@ def get_noisy_dataset_from_config(self): """ Create a dataset for images with noisy labels based on configuraiton. """ - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset'].get('modal_num', 1) - transform_names = self.config['dataset']['train_transform'] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: + trans_names, trans_params = self.get_transform_names_and_parameters('train') + transform_list = [] + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) + one_transform = self.transform_dict[name](trans_params) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) + modal_num = self.config['dataset'].get('modal_num', 1) csv_file = self.config['dataset'].get('train_csv_noise', None) - dataset = NiftyDataset(root_dir=root_dir, + dataset = NiftyDataset(root_dir = self.config['dataset']['train_dir'], csv_file = csv_file, modal_num = modal_num, with_label= True, - transform = data_transform ) + transform = data_transform , + task = self.task_type) return dataset + def create_dataset(self): super(NLLDAST, self).create_dataset() if(self.stage == 'train'): diff --git a/pymic/net_run/noisy_label/nll_trinet.py b/pymic/net_run/noisy_label/nll_trinet.py index 25c90cf..64d87b6 100644 --- a/pymic/net_run/noisy_label/nll_trinet.py +++ b/pymic/net_run/noisy_label/nll_trinet.py @@ -17,24 +17,6 @@ from pymic.util.parse_config import * from pymic.util.ramps import get_rampup_ratio -class TriNet(nn.Module): - def __init__(self, params): - super(TriNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - self.net3 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - out3 = self.net3(x) - - if(self.training): - return out1, out2, out3 - else: - return (out1 + out2 + out3) / 3 - class NLLTriNet(SegmentationAgent): """ Implementation of trinet for learning from noisy samples for @@ -56,14 +38,6 @@ class NLLTriNet(SegmentationAgent): def __init__(self, config, stage = 'train'): super(NLLTriNet, self).__init__(config, stage) - def create_network(self): - if(self.net is None): - self.net = TriNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def get_loss_and_confident_mask(self, pred, labels_prob, conf_ratio): prob = nn.Softmax(dim = 1)(pred) prob_2d = reshape_tensor_to_2D(prob) * 0.999 + 5e-4 diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py index f2bbe0f..3b34887 100644 --- a/pymic/net_run/preprocess.py +++ b/pymic/net_run/preprocess.py @@ -1,82 +1,30 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division -import logging import os import sys -import shutil from datetime import datetime -from pymic import TaskType from pymic.util.parse_config import * -from pymic.net_run.agent_cls import ClassificationAgent -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run.semi_sup import SSLMethodDict -from pymic.net_run.weak_sup import WSLMethodDict -from pymic.net_run.self_sup import SelfSupMethodDict -from pymic.net_run.noisy_label import NLLMethodDict -# from pymic.net_run.self_sup import SelfSLSegAgent +from pymic.net_run.agent_preprocess import PreprocessAgent -def get_seg_rec_agent(config, sup_type): - assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) - if(sup_type == 'fully_sup'): - logging.info("\n********** Fully Supervised Learning **********\n") - agent = SegmentationAgent(config, 'train') - elif(sup_type == 'semi_sup'): - logging.info("\n********** Semi Supervised Learning **********\n") - method = config['semi_supervised_learning']['method_name'] - agent = SSLMethodDict[method](config, 'train') - elif(sup_type == 'weak_sup'): - logging.info("\n********** Weakly Supervised Learning **********\n") - method = config['weakly_supervised_learning']['method_name'] - agent = WSLMethodDict[method](config, 'train') - elif(sup_type == 'noisy_label'): - logging.info("\n********** Noisy Label Learning **********\n") - method = config['noisy_label_learning']['method_name'] - agent = NLLMethodDict[method](config, 'train') - elif(sup_type == 'self_sup'): - logging.info("\n********** Self Supervised Learning **********\n") - method = config['self_supervised_learning']['method_name'] - agent = SelfSupMethodDict[method](config, 'train') - else: - raise ValueError("undefined supervision type: {0:}".format(sup_type)) - return agent def main(): """ - The main function for running a network for training. + The main function for data preprocessing. """ if(len(sys.argv) < 2): print('Number of arguments should be 2. e.g.') - print(' pymic_train config.cfg') + print(' pymic_preprocess config.cfg') exit() cfg_file = str(sys.argv[1]) if(not os.path.isfile(cfg_file)): raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.makedirs(log_dir, exist_ok=True) - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - shutil.copy(cfg_file, log_dir + "/" + dst_cfg) - datetime_str = str(datetime.now())[:-7].replace(":", "_") - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), - level=logging.INFO, format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str), - level=logging.INFO, format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - task = config['dataset']['task_type'] - if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST): - agent = ClassificationAgent(config, 'train') - else: - sup_type = config['dataset'].get('supervise_type', 'fully_sup') - agent = get_seg_rec_agent(config, sup_type) - + config = parse_config(cfg_file) + config = synchronize_config(config) + agent = PreprocessAgent(config) agent.run() if __name__ == "__main__": main() + diff --git a/pymic/net_run/semi_sup/ssl_mt.py b/pymic/net_run/semi_sup/ssl_mt.py index 2a2abb8..409af19 100644 --- a/pymic/net_run/semi_sup/ssl_mt.py +++ b/pymic/net_run/semi_sup/ssl_mt.py @@ -106,7 +106,7 @@ def training(self): alpha = ssl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run/semi_sup/ssl_uamt.py b/pymic/net_run/semi_sup/ssl_uamt.py index 6222fe3..053a012 100644 --- a/pymic/net_run/semi_sup/ssl_uamt.py +++ b/pymic/net_run/semi_sup/ssl_uamt.py @@ -108,7 +108,7 @@ def training(self): alpha = ssl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 50a5fb7..3a4571f 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -14,7 +14,6 @@ from pymic.net_run.weak_sup import WSLMethodDict from pymic.net_run.self_sup import SelfSupMethodDict from pymic.net_run.noisy_label import NLLMethodDict -# from pymic.net_run.self_sup import SelfSLSegAgent def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) diff --git a/pymic/net_run/weak_sup/wsl_ustm.py b/pymic/net_run/weak_sup/wsl_ustm.py index 0ea3fbc..31a6644 100644 --- a/pymic/net_run/weak_sup/wsl_ustm.py +++ b/pymic/net_run/weak_sup/wsl_ustm.py @@ -125,7 +125,7 @@ def training(self): alpha = wsl_cfg.get('ema_decay', 0.99) alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha) for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()): - ema_param.data.mul_(alpha).add_(1 - alpha, param.data) + ema_param.data.mul_(alpha).add(param.data, alpha = 1.0 - alpha) train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 5f0e4ec..643c12e 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -38,7 +38,7 @@ def __init__(self, params): self.chns = params.get('NormalizeWithMeanStd_channels'.lower(), None) self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) self.std = params.get('NormalizeWithMeanStd_std'.lower(), None) - self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), 1.0) + self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), None) self.bg_random = params.get('NormalizeWithMeanStd_set_background_to_random'.lower(), True) self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index a9b114b..82939ce 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -286,63 +286,65 @@ def evaluation(config): label_list = [label_list] label_fuse = config.get('label_fuse', False) output_name = config.get('output_name', None) - gt_root = config['ground_truth_folder_root'] - seg_root = config['segmentation_folder_root'] + gt_dir = config['ground_truth_folder'] + seg_dirs = config['segmentation_folder'] image_pair_csv = config.get('evaluation_image_pair', None) + if(not isinstance(seg_dirs, (tuple, list))): + seg_dirs = [seg_dirs] if(image_pair_csv is not None): image_pair = pd.read_csv(image_pair_csv) gt_names, seg_names = image_pair.iloc[:, 0], image_pair.iloc[:, 1] else: - seg_names = sorted(os.listdir(seg_root)) + seg_names = sorted(os.listdir(seg_dirs[0])) seg_names = [item for item in seg_names if is_image_name(item)] gt_names = seg_names + for seg_dir in seg_dirs: + for metric in metric_list: + print(metric) + score_all_data = [] + name_score_list= [] + for i in range(len(gt_names)): + gt_full_name = join(gt_dir, gt_names[i]) + seg_full_name = join(seg_dir, seg_names[i]) + s_dict = load_image_as_nd_array(seg_full_name) + g_dict = load_image_as_nd_array(gt_full_name) + s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] + g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] + # for dim in range(len(s_spacing)): + # assert(s_spacing[dim] == g_spacing[dim]) + + score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, + label_fuse, s_spacing, metric ) + if(len(label_list) > 1): + score_vector.append(np.asarray(score_vector).mean()) + score_all_data.append(score_vector) + name_score_list.append([seg_names[i]] + score_vector) + print(seg_names[i], score_vector) + score_all_data = np.asarray(score_all_data) + score_mean = score_all_data.mean(axis = 0) + score_std = score_all_data.std(axis = 0) + name_score_list.append(['mean'] + list(score_mean)) + name_score_list.append(['std'] + list(score_std)) - for metric in metric_list: - print(metric) - score_all_data = [] - name_score_list= [] - for i in range(len(gt_names)): - gt_full_name = join(gt_root, gt_names[i]) - seg_full_name = join(seg_root, seg_names[i]) - s_dict = load_image_as_nd_array(seg_full_name) - g_dict = load_image_as_nd_array(gt_full_name) - s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] - g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] - # for dim in range(len(s_spacing)): - # assert(s_spacing[dim] == g_spacing[dim]) - - score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, - label_fuse, s_spacing, metric ) - if(len(label_list) > 1): - score_vector.append(np.asarray(score_vector).mean()) - score_all_data.append(score_vector) - name_score_list.append([seg_names[i]] + score_vector) - print(seg_names[i], score_vector) - score_all_data = np.asarray(score_all_data) - score_mean = score_all_data.mean(axis = 0) - score_std = score_all_data.std(axis = 0) - name_score_list.append(['mean'] + list(score_mean)) - name_score_list.append(['std'] + list(score_std)) - - # save the result as csv - if(output_name is None): - metric_output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) - else: - metric_output_name = output_name - with open(metric_output_name, mode='w') as csv_file: - csv_writer = csv.writer(csv_file, delimiter=',', - quotechar='"',quoting=csv.QUOTE_MINIMAL) - head = ['image'] + ["class_{0:}".format(i) for i in label_list] - if(len(label_list) > 1): - head = head + ["average"] - csv_writer.writerow(head) - for item in name_score_list: - csv_writer.writerow(item) - - print("{0:} mean ".format(metric), score_mean) - print("{0:} std ".format(metric), score_std) + # save the result as csv + if(output_name is None): + metric_output_name = "{0:}/eval_{1:}.csv".format(seg_dir, metric) + else: + metric_output_name = output_name + with open(metric_output_name, mode='w') as csv_file: + csv_writer = csv.writer(csv_file, delimiter=',', + quotechar='"',quoting=csv.QUOTE_MINIMAL) + head = ['image'] + ["class_{0:}".format(i) for i in label_list] + if(len(label_list) > 1): + head = head + ["average"] + csv_writer.writerow(head) + for item in name_score_list: + csv_writer.writerow(item) + + print("{0:} mean ".format(metric), score_mean) + print("{0:} std ".format(metric), score_std) def main(): """ diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index c813e5d..158569c 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -38,6 +38,18 @@ def get_ND_bounding_box(volume, margin = None): bb_max[i] = min(bb_max[i] + margin[i], input_shape[i]) return bb_min, bb_max +def get_human_region_from_ct(image, threshold_i = -600, threshold_z = 0.6): + input_shape = image.shape + mask = np.asarray(image > threshold_i) + mask2d = np.mean(mask, axis = 0) > threshold_z + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + bb_min = [0] + bbmin + bb_max = list(input_shape[:1]) + bbmax + return bb_min, bb_max + def crop_ND_volume_with_bounding_box(volume, bb_min, bb_max): """ Extract a subregion form an ND image. diff --git a/pymic/util/parse_config.py b/pymic/util/parse_config.py index a12cc76..0e38b91 100644 --- a/pymic/util/parse_config.py +++ b/pymic/util/parse_config.py @@ -102,24 +102,35 @@ def parse_config(filename): def synchronize_config(config): data_cfg = config['dataset'] - net_cfg = config['network'] - # data_cfg["modal_num"] = net_cfg["in_chns"] data_cfg["task_type"] = TaskDict[data_cfg["task_type"]] - data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] - if "PartialLabelToProbability" in data_cfg['train_transform']: + if('network' in config): + net_cfg = config['network'] + # data_cfg["modal_num"] = net_cfg["in_chns"] + data_cfg["LabelToProbability_class_num".lower()] = net_cfg["class_num"] + transform = [] + if('transform' in data_cfg and data_cfg['transform'] is not None): + transform.extend(data_cfg['transform']) + if('train_transform' in data_cfg and data_cfg['train_transform'] is not None): + transform.extend(data_cfg['train_transform']) + if('valid_transform' in data_cfg and data_cfg['valid_transform'] is not None): + transform.extend(data_cfg['valid_transform']) + if('test_transform' in data_cfg and data_cfg['test_transform'] is not None): + transform.extend(data_cfg['test_transform']) + if ( "PartialLabelToProbability" in transform and 'network' in config): data_cfg["PartialLabelToProbability_class_num".lower()] = net_cfg["class_num"] patch_size = data_cfg.get('patch_size', None) if(patch_size is not None): - if('Pad' in data_cfg['train_transform']): + if('Pad' in transform and 'Pad_output_size'.lower() not in data_cfg): data_cfg['Pad_output_size'.lower()] = patch_size - if('CenterCrop' in data_cfg['train_transform']): + if('CenterCrop' in transform and 'CenterCrop_output_size'.lower() not in data_cfg): data_cfg['CenterCrop_output_size'.lower()] = patch_size - if('RandomCrop' in data_cfg['train_transform']): + if('RandomCrop' in transform and 'RandomCrop_output_size'.lower() not in data_cfg): data_cfg['RandomCrop_output_size'.lower()] = patch_size - if('RandomResizedCrop' in data_cfg['train_transform']): + if('RandomResizedCrop' in transform and \ + 'RandomResizedCrop_output_size'.lower() not in data_cfg): data_cfg['RandomResizedCrop_output_size'.lower()] = patch_size config['dataset'] = data_cfg - config['network'] = net_cfg + # config['network'] = net_cfg return config def logging_config(config): diff --git a/pymic/util/preprocess.py b/pymic/util/preprocess.py deleted file mode 100644 index c0dc9a1..0000000 --- a/pymic/util/preprocess.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -import os -import numpy as np -import SimpleITK as sitk -from pymic.io.image_read_write import load_image_as_nd_array -from pymic.transform.trans_dict import TransformDict -from pymic.util.parse_config import parse_config - -def get_transform_list(trans_config_file): - """ - Create a list of transforms given a configuration file. - """ - config = parse_config(trans_config_file) - transform_list = [] - - transform_param = config['dataset'] - transform_param['task'] = 'segmentation' - transform_names = config['dataset']['transform'] - for name in transform_names: - print(name) - if(name not in TransformDict): - raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = TransformDict[name](transform_param) - transform_list.append(one_transform) - return transform_list - -def preprocess_with_transform(transforms, img_in_name, img_out_name, - lab_in_name = None, lab_out_name = None): - """ - Using a list of data transforms for preprocessing, - such as image normalization, cropping, etc. - TODO: support multip-modality preprocessing. - - :param transforms: (list) A list of transform objects. - :param img_in_name: (str) Input file name. - :param img_out_name: (str) Output file name. - :param lab_in_name: (optional, str) If None, load the image's - corresponding label for preprocessing as well. - :param lab_out_name: (optional, str) The output label name. - """ - image_dict = load_image_as_nd_array(img_in_name) - sample = {'image': np.asarray(image_dict['data_array'], np.float32), - 'origin':image_dict['origin'], - 'spacing': image_dict['spacing'], - 'direction':image_dict['direction']} - if(lab_in_name is not None): - label_dict = load_image_as_nd_array(lab_in_name) - sample['label'] = label_dict['data_array'] - for transform in transforms: - sample = transform(sample) - - out_img = sitk.GetImageFromArray(sample['image'][0]) - out_img.SetSpacing(sample['spacing']) - out_img.SetOrigin(sample['origin']) - out_img.SetDirection(sample['direction']) - sitk.WriteImage(out_img, img_out_name) - if(lab_in_name is not None and lab_out_name is not None): - out_lab = sitk.GetImageFromArray(sample['label'][0]) - out_lab.CopyInformation(out_img) - sitk.WriteImage(out_lab, lab_out_name) - - - diff --git a/setup.py b/setup.py index 36daf9a..ebb738f 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.4.0", + version = "0.4.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -41,6 +41,7 @@ python_requires = '>=3.6', entry_points = { 'console_scripts': [ + 'pymic_preprocess = pymic.net_run.preprocess:main', 'pymic_train = pymic.net_run.train:main', 'pymic_test = pymic.net_run.predict:main', 'pymic_eval_cls = pymic.util.evaluation_cls:main',