diff --git a/.gitignore b/.gitignore index f7da7ac..639f873 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ dist/* *egg*/* *stop* files.txt +pymic/test/runs/* # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks diff --git a/docs/source/api.rst b/docs/source/api.rst index 206000d..d09809c 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -9,8 +9,5 @@ API pymic.loss pymic.net pymic.net_run - pymic.net_run_nll - pymic.net_run_ssl - pymic.net_run_wsl pymic.transform pymic.util \ No newline at end of file diff --git a/docs/source/pymic.net.net2d.rst b/docs/source/pymic.net.net2d.rst index d978dfe..bd54bc6 100644 --- a/docs/source/pymic.net.net2d.rst +++ b/docs/source/pymic.net.net2d.rst @@ -52,6 +52,14 @@ pymic.net.net2d.unet2d\_dual\_branch module :undoc-members: :show-inheritance: +pymic.net.net2d.unet2d\_mcnet module +------------------------------------------- + +.. automodule:: pymic.net.net2d.unet2d_mcnet + :members: + :undoc-members: + :show-inheritance: + pymic.net.net2d.unet2d\_nest module ----------------------------------- diff --git a/docs/source/pymic.net_run.semi_sup.rst b/docs/source/pymic.net_run.semi_sup.rst index 6ed157d..15692b2 100644 --- a/docs/source/pymic.net_run.semi_sup.rst +++ b/docs/source/pymic.net_run.semi_sup.rst @@ -28,6 +28,14 @@ pymic.net\_run.semi\_sup.ssl\_cps module :undoc-members: :show-inheritance: +pymic.net\_run.semi\_sup.ssl\_mcnet module +------------------------------------------- + +.. automodule:: pymic.net_run.semi_sup.ssl_mcnet + :members: + :undoc-members: + :show-inheritance: + pymic.net\_run.semi\_sup.ssl\_em module --------------------------------------- diff --git a/pymic/io/h5_dataset.py b/pymic/io/h5_dataset.py index 02f94f3..34fa1a4 100644 --- a/pymic/io/h5_dataset.py +++ b/pymic/io/h5_dataset.py @@ -8,8 +8,9 @@ import pandas as pd from torch.utils.data import Dataset from torch.utils.data.sampler import Sampler +from pymic import TaskType -class H5DataSet(Dataset): +class H5DataSet_backup(Dataset): """ Dataset for loading images stored in h5 format. 
It generates 4D tensors with dimention order [C, D, H, W] for 3D images, and @@ -39,7 +40,9 @@ def __getitem__(self, idx): if self.transform: sample = self.transform(sample) return sample - + + + class TwoStreamBatchSampler(Sampler): """Iterate two sets of indices diff --git a/pymic/io/image_read_write.py b/pymic/io/image_read_write.py index cb65e19..3aa87bd 100644 --- a/pymic/io/image_read_write.py +++ b/pymic/io/image_read_write.py @@ -23,11 +23,9 @@ def load_nifty_volume_as_4d_array(filename): spacing = img_obj.GetSpacing() direction = img_obj.GetDirection() shape = data_array.shape - if(len(shape) == 4): - assert(shape[3] == 1) - elif(len(shape) == 3): + if(len(shape) == 3): data_array = np.expand_dims(data_array, axis = 0) - else: + elif(len(shape) > 4 or len(shape) < 3): raise ValueError("unsupported image dim: {0:}".format(len(shape))) output = {} output['data_array'] = data_array @@ -81,10 +79,10 @@ def load_image_as_nd_array(image_name): image_name.endswith(".tif") or image_name.endswith(".png")): image_dict = load_rgb_image_as_3d_array(image_name) else: - raise ValueError("unsupported image format") + raise ValueError("unsupported image format: {0:}".format(image_name)) return image_dict -def save_array_as_nifty_volume(data, image_name, reference_name = None): +def save_array_as_nifty_volume(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]): """ Save a numpy array as nifty image @@ -92,6 +90,7 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): :param image_name: (str) The ouput file name. :param reference_name: (str) File name of the reference image of which meta information is used. + :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. """ img = sitk.GetImageFromArray(data) if(reference_name is not None): @@ -99,7 +98,13 @@ def save_array_as_nifty_volume(data, image_name, reference_name = None): #img.CopyInformation(img_ref) img.SetSpacing(img_ref.GetSpacing()) img.SetOrigin(img_ref.GetOrigin()) - img.SetDirection(img_ref.GetDirection()) + direction0 = img_ref.GetDirection() + direction1 = img.GetDirection() + if(len(direction0) == len(direction1)): + img.SetDirection(direction0) + else: + nifty_spacing = spacing[1:] + spacing[:1] + img.SetSpacing(nifty_spacing) sitk.WriteImage(img, image_name) def save_array_as_rgb_image(data, image_name): @@ -118,7 +123,7 @@ def save_array_as_rgb_image(data, image_name): img = Image.fromarray(data) img.save(image_name) -def save_nd_array_as_image(data, image_name, reference_name = None): +def save_nd_array_as_image(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]): """ Save a 3D or 2D numpy array as medical image or RGB image @@ -126,13 +131,14 @@ def save_nd_array_as_image(data, image_name, reference_name = None): [H, W, 3] or [H, W]. :param reference_name: (str) File name of the reference image of which meta information is used. + :param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided. 
""" data_dim = len(data.shape) assert(data_dim == 2 or data_dim == 3) if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or image_name.endswith(".mha")): assert(data_dim == 3) - save_array_as_nifty_volume(data, image_name, reference_name) + save_array_as_nifty_volume(data, image_name, reference_name, spacing) elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or image_name.endswith(".tif") or image_name.endswith(".png")): diff --git a/pymic/io/nifty_dataset.py b/pymic/io/nifty_dataset.py index 9812d13..aefe4da 100644 --- a/pymic/io/nifty_dataset.py +++ b/pymic/io/nifty_dataset.py @@ -3,11 +3,9 @@ import logging import os -import torch import pandas as pd import numpy as np -from torch.utils.data import Dataset, DataLoader -from torchvision import transforms, utils +from torch.utils.data import Dataset from pymic import TaskType from pymic.io.image_read_write import load_image_as_nd_array @@ -38,7 +36,8 @@ def __init__(self, root_dir, csv_file, modal_num = 1, if('label' not in csv_keys): logging.warning("`label` section is not found in the csv file {0:}".format( csv_file) + "\n -- This is only allowed for self-supervised learning" + - "\n -- when `SelfSuperviseLabel` is used in the transform.") + "\n -- when `SelfSuperviseLabel` is used in the transform, or when" + + "\n -- loading the unlabeled data for preprocessing.") self.with_label = False self.image_weight_idx = None self.pixel_weight_idx = None @@ -52,15 +51,15 @@ def __len__(self): def __getlabel__(self, idx): csv_keys = list(self.csv_items.keys()) - label_idx = csv_keys.index('label') - label_name = "{0:}/{1:}".format(self.root_dir, - self.csv_items.iloc[idx, label_idx]) - label = load_image_as_nd_array(label_name)['data_array'] + label_idx = csv_keys.index('label') + label_name = self.csv_items.iloc[idx, label_idx] + label_name_full = "{0:}/{1:}".format(self.root_dir, label_name) + label = load_image_as_nd_array(label_name_full)['data_array'] if(self.task == TaskType.SEGMENTATION): label = np.asarray(label, np.int32) elif(self.task == TaskType.RECONSTRUCTION): label = np.asarray(label, np.float32) - return label + return label, label_name def __get_pixel_weight__(self, idx): weight_name = "{0:}/{1:}".format(self.root_dir, @@ -69,6 +68,25 @@ def __get_pixel_weight__(self, idx): weight = np.asarray(weight, np.float32) return weight + # def __getitem__(self, idx): + # sample_name = self.csv_items.iloc[idx, 0] + # h5f = h5py.File(self.root_dir + '/' + sample_name, 'r') + # image = np.asarray(h5f['image'][:], np.float32) + + # # this a temporaory process, will be delieted later + # if(len(image.shape) == 3 and image.shape[0] > 1): + # image = np.expand_dims(image, 0) + # sample = {'image': image, 'names':sample_name} + + # if('label' in h5f): + # label = np.asarray(h5f['label'][:], np.uint8) + # if(len(label.shape) == 3 and label.shape[0] > 1): + # label = np.expand_dims(label, 0) + # sample['label'] = label + # if self.transform: + # sample = self.transform(sample) + # return sample + def __getitem__(self, idx): names_list, image_list = [], [] for i in range (self.modal_num): @@ -80,12 +98,14 @@ def __getitem__(self, idx): image_list.append(image_data) image = np.concatenate(image_list, axis = 0) image = np.asarray(image, np.float32) - sample = {'image': image, 'names' : names_list[0], + + sample = {'image': image, 'names' : names_list, 'origin':image_dict['origin'], 'spacing': image_dict['spacing'], 'direction':image_dict['direction']} if (self.with_label): - sample['label'] = self.__getlabel__(idx) + 
sample['label'], label_name = self.__getlabel__(idx) + sample['names'].append(label_name) assert(image.shape[1:] == sample['label'].shape[1:]) if (self.image_weight_idx is not None): sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx] diff --git a/pymic/loss/seg/abstract.py b/pymic/loss/seg/abstract.py index f42d816..68643e8 100644 --- a/pymic/loss/seg/abstract.py +++ b/pymic/loss/seg/abstract.py @@ -16,9 +16,20 @@ class AbstractSegLoss(nn.Module): def __init__(self, params = None): super(AbstractSegLoss, self).__init__() if(params is None): - self.softmax = True + self.acti_func = 'softmax' else: - self.softmax = params.get('loss_softmax', True) + self.acti_func = params.get('loss_acti_func', 'softmax') + + def get_activated_prediction(self, p, acti_func = 'softmax'): + if(acti_func == "softmax"): + p = nn.Softmax(dim = 1)(p) + elif(acti_func == "tanh"): + p = nn.Tanh()(p) + elif(acti_func == "sigmoid"): + p = nn.Sigmoid()(p) + else: + raise ValueError("activation for output is not supported: {0:}".format(acti_func)) + return p def forward(self, loss_input_dict): """ diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index 9524d57..4edbbc3 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -13,8 +13,10 @@ class CrossEntropyLoss(AbstractSegLoss): The parameters should be written in the `params` dictionary, and it has the following fields: - :param `loss_softmax`: (optional, bool) - Apply softmax to the prediction of network or not. Default is True. + :param `loss_acti_func`: (optional, string) + Apply an activation function to the prediction of network or not, for example, + 'softmax' for image segmentation tasks, 'tanh' for reconstruction tasks, and None + means no activation is used. """ def __init__(self, params = None): super(CrossEntropyLoss, self).__init__(params) @@ -27,8 +29,9 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) + predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) @@ -74,8 +77,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y diff --git a/pymic/loss/seg/dice.py b/pymic/loss/seg/dice.py index 2c2df32..c423c2c 100644 --- a/pymic/loss/seg/dice.py +++ b/pymic/loss/seg/dice.py @@ -25,8 +25,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): @@ -52,8 +52,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = 1.0 - predict[:, :1, :, :, :] soft_y = 1.0 - soft_y[:, :1, :, :, :] predict = reshape_tensor_to_2D(predict) @@ -76,8 +76,8 @@ 
def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) num_class = list(predict.size())[1] @@ -115,8 +115,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) @@ -149,8 +149,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) diff --git a/pymic/loss/seg/exp_log.py b/pymic/loss/seg/exp_log.py index c1b3f00..8c0d494 100644 --- a/pymic/loss/seg/exp_log.py +++ b/pymic/loss/seg/exp_log.py @@ -32,8 +32,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) diff --git a/pymic/loss/seg/mse.py b/pymic/loss/seg/mse.py index 5b657c5..eb53af4 100644 --- a/pymic/loss/seg/mse.py +++ b/pymic/loss/seg/mse.py @@ -19,8 +19,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) mse = torch.square(predict - soft_y) mse = torch.mean(mse) return mse @@ -44,8 +44,8 @@ def forward(self, loss_input_dict): if(isinstance(predict, (list, tuple))): predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) + if(self.acti_func is not None): + predict = self.get_activated_prediction(predict, self.acti_func) mae = torch.abs(predict - soft_y) if(weight is None): mae = torch.mean(mae) diff --git a/pymic/net/multi_net.py b/pymic/net/multi_net.py new file mode 100644 index 0000000..78209b1 --- /dev/null +++ b/pymic/net/multi_net.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import torch.nn as nn + +class MultiNet(nn.Module): + ''' + A combination of multiple networks. + Parameters should be saved in the `params` dictionary. + + :param `net_names`: (list) A list of network class name. + :param `infer_mode`: (int) Mode for inference. 0: only use the first network. + 1: taking an average of all the networks. 
+ ''' + def __init__(self, net_dict, params): + super(MultiNet, self).__init__() + net_names = params['net_type'] # should be a list of network class names + self.output_mode = params.get('infer_mode', 0) + self.networks = nn.ModuleList([net_dict[item](params) for item in net_names]) + + def forward(self, x): + if(self.training): + output = [net(x) for net in self.networks] + else: + output = self.networks[0](x) + if(self.output_mode == 1): + for i in range(1, len(self.networks)): + output += self.networks[i](x) + output = output / len(self.networks) + return output + \ No newline at end of file diff --git a/pymic/net/net2d/canet_module.py b/pymic/net/net2d/canet_module.py new file mode 100644 index 0000000..097a4f1 --- /dev/null +++ b/pymic/net/net2d/canet_module.py @@ -0,0 +1,578 @@ +# -*- coding: utf-8 -*- +""" +Building blocks for CA-Net. + +Original file is on `Github. +`_ +""" + +from __future__ import print_function, division +import torch +import torch.nn as nn +import functools +from torch.nn import functional as F + + +class conv_block(nn.Module): + def __init__(self, ch_in, ch_out, drop_out=False): + super(conv_block, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), + nn.BatchNorm2d(ch_out), + nn.ReLU(inplace=True), + ) + self.dropout = drop_out + + def forward(self, x): + x = self.conv(x) + if self.dropout: + x = nn.Dropout2d(0.5)(x) + return x + + +# # UpCat(nn.Module) for U-net UP convolution +class UpCat(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True): + super(UpCat, self).__init__() + if is_deconv: + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv? 
+ outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = torch.cat([inputs, outputs], dim=1) + + return out + + +# # UpCatconv(nn.Module) for up convolution +class UpCatconv(nn.Module): + def __init__(self, in_feat, out_feat, is_deconv=True, drop_out=False): + super(UpCatconv, self).__init__() + + if is_deconv: + self.conv = conv_block(in_feat, out_feat, drop_out=drop_out) + self.up = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2) + else: + self.conv = conv_block(in_feat + out_feat, out_feat, drop_out=drop_out) + self.up = nn.Upsample(scale_factor=2, mode='bilinear') + + def forward(self, inputs, down_outputs): + # TODO: Upsampling required after deconv + outputs = self.up(down_outputs) + offset = inputs.size()[3] - outputs.size()[3] + if offset == 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2]), out=None).unsqueeze( + 3).cuda() + outputs = torch.cat([outputs, addition], dim=3) + elif offset > 1: + addition = torch.rand((outputs.size()[0], outputs.size()[1], outputs.size()[2], offset), out=None).cuda() + outputs = torch.cat([outputs, addition], dim=3) + out = self.conv(torch.cat([inputs, outputs], dim=1)) + + return out + + +class UnetDsv3(nn.Module): + def __init__(self, in_size, out_size, scale_factor): + super(UnetDsv3, self).__init__() + self.dsv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0), + nn.Upsample(size=scale_factor, mode='bilinear'), ) + + def forward(self, input): + return self.dsv(input) + + +###### Intial weights ##### +def weights_init_normal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.normal(m.weight.data, 0.0, 0.02) + elif classname.find('Linear') != -1: + nn.init.normal(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def weights_init_xavier(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.xavier_normal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + nn.init.xavier_normal(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def weights_init_orthogonal(m): + classname = m.__class__.__name__ + #print(classname) + if classname.find('Conv') != -1: + nn.init.orthogonal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + nn.init.orthogonal(m.weight.data, gain=1) + elif classname.find('BatchNorm') != -1: + nn.init.normal(m.weight.data, 1.0, 0.02) + nn.init.constant(m.bias.data, 0.0) + + +def init_weights(net, init_type='normal'): + #print('initialization method [%s]' % 
init_type) + if init_type == 'normal': + net.apply(weights_init_normal) + elif init_type == 'xavier': + net.apply(weights_init_xavier) + elif init_type == 'kaiming': + net.apply(weights_init_kaiming) + elif init_type == 'orthogonal': + net.apply(weights_init_orthogonal) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +###### For attention ###### +class _GridAttentionBlockND(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(_GridAttentionBlockND, self).__init__() + + assert dimension in [2, 3] + assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] + + # Downsampling rate for the input featuremap + if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor + elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) + else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # Output transform + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=True) + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=(1, 1), stride=1, padding=0, bias=True) + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + # Define the operation + if mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concatenation_debug': + self.operation_function = self._concatenation_debug + elif mode == 'concatenation_residual': + self.operation_function = self._concatenation_residual + else: + raise NotImplementedError('Unknown operation function.') + + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, 
h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = F.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + def _concatenation_debug(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.softplus(theta_x + phi_g) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + sigm_psi_f = F.sigmoid(self.psi(f)) + + # upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + + def _concatenation_residual(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) + # phi => (b, g_d) -> (b, i_c) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') + # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + f = F.relu(theta_x + phi_g, inplace=True) + + # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) + f = self.psi(f).view(batch_size, 1, -1) + sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) + + # upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2, 2)): + super(GridAttentionBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + + +class GridAttentionBlock3D(_GridAttentionBlockND): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(2,2,2)): + super(GridAttentionBlock3D, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=3, mode=mode, + sub_sample_factor=sub_sample_factor, + ) + +class _GridAttentionBlockND_TORR(nn.Module): + def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', + sub_sample_factor=(1,1,1), bn_layer=True, use_W=True, use_phi=True, use_theta=True, use_psi=True, nonlinearity1='relu'): + super(_GridAttentionBlockND_TORR, self).__init__() + + assert dimension in [2, 
3] + assert mode in ['concatenation', 'concatenation_softmax', + 'concatenation_sigmoid', 'concatenation_mean', + 'concatenation_range_normalise', 'concatenation_mean_flow'] + + # Default parameter set + self.mode = mode + self.dimension = dimension + self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, tuple) else tuple([sub_sample_factor])*dimension + self.sub_sample_kernel_size = self.sub_sample_factor + + # Number of channels (pixel dimensions) + self.in_channels = in_channels + self.gating_channels = gating_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = nn.Conv3d + bn = nn.BatchNorm3d + self.upsample_mode = 'trilinear' + elif dimension == 2: + conv_nd = nn.Conv2d + bn = nn.BatchNorm2d + self.upsample_mode = 'bilinear' + else: + raise NotImplemented + + # initialise id functions + # Theta^T * x_ij + Phi^T * gating_signal + bias + self.W = lambda x: x + self.theta = lambda x: x + self.psi = lambda x: x + self.phi = lambda x: x + self.nl1 = lambda x: x + + if use_W: + if bn_layer: + self.W = nn.Sequential( + conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), + bn(self.in_channels), + ) + else: + self.W = conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0) + + if use_theta: + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) + + + if use_phi: + self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, + kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) + + + if use_psi: + self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) + + + if nonlinearity1: + if nonlinearity1 == 'relu': + self.nl1 = lambda x: F.relu(x, inplace=True) + + if 'concatenation' in mode: + self.operation_function = self._concatenation + else: + raise NotImplementedError('Unknown operation function.') + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + + if use_psi and self.mode == 'concatenation_sigmoid': + nn.init.constant(self.psi.bias.data, 3.0) + + if use_psi and self.mode == 'concatenation_softmax': + nn.init.constant(self.psi.bias.data, 10.0) + + # if use_psi and self.mode == 'concatenation_mean': + # nn.init.constant(self.psi.bias.data, 3.0) + + # if use_psi and self.mode == 'concatenation_range_normalise': + # nn.init.constant(self.psi.bias.data, 3.0) + + parallel = False + if parallel: + if use_W: self.W = nn.DataParallel(self.W) + if use_phi: self.phi = nn.DataParallel(self.phi) + if use_psi: self.psi = nn.DataParallel(self.psi) + if use_theta: self.theta = nn.DataParallel(self.theta) + + def forward(self, x, g): + ''' + :param x: (b, c, t, h, w) + :param g: (b, g_d) + :return: + ''' + + output = self.operation_function(x, g) + return output + + def _concatenation(self, x, g): + input_size = x.size() + batch_size = input_size[0] + assert batch_size == g.size(0) + + ############################# + # compute compatibility score + + # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) + # phi => (b, c, t, h, w) -> (b, i_c, t, h, w) + theta_x = self.theta(x) + theta_x_size = theta_x.size() + + # nl(theta.x + phi.g + 
bias) -> f = (b, i_c, t/s1, h/s2, w/s3) + phi_g = F.upsample(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) + + f = theta_x + phi_g + f = self.nl1(f) + + psi_f = self.psi(f) + + ############################################ + # normalisation -- scale compatibility score + # psi^T . f -> (b, 1, t/s1, h/s2, w/s3) + if self.mode == 'concatenation_softmax': + sigm_psi_f = F.softmax(psi_f.view(batch_size, 1, -1), dim=2) + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + elif self.mode == 'concatenation_mean': + psi_f_flat = psi_f.view(batch_size, 1, -1) + psi_f_sum = torch.sum(psi_f_flat, dim=2)#clamp(1e-6) + psi_f_sum = psi_f_sum[:,:,None].expand_as(psi_f_flat) + + sigm_psi_f = psi_f_flat / psi_f_sum + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + elif self.mode == 'concatenation_mean_flow': + psi_f_flat = psi_f.view(batch_size, 1, -1) + ss = psi_f_flat.shape + psi_f_min = psi_f_flat.min(dim=2)[0].view(ss[0],ss[1],1) + psi_f_flat = psi_f_flat - psi_f_min + psi_f_sum = torch.sum(psi_f_flat, dim=2).view(ss[0],ss[1],1).expand_as(psi_f_flat) + + sigm_psi_f = psi_f_flat / psi_f_sum + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + elif self.mode == 'concatenation_range_normalise': + psi_f_flat = psi_f.view(batch_size, 1, -1) + ss = psi_f_flat.shape + psi_f_max = torch.max(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) + psi_f_min = torch.min(psi_f_flat, dim=2)[0].view(ss[0], ss[1], 1) + + sigm_psi_f = (psi_f_flat - psi_f_min) / (psi_f_max - psi_f_min).expand_as(psi_f_flat) + sigm_psi_f = sigm_psi_f.view(batch_size, 1, *theta_x_size[2:]) + + elif self.mode == 'concatenation_sigmoid': + sigm_psi_f = F.sigmoid(psi_f) + else: + raise NotImplementedError + + # sigm_psi_f is attention map! upsample the attentions and multiply + sigm_psi_f = F.upsample(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) + y = sigm_psi_f.expand_as(x) * x + W_y = self.W(y) + + return W_y, sigm_psi_f + + +class GridAttentionBlock2D_TORR(_GridAttentionBlockND_TORR): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(1,1), bn_layer=True, + use_W=True, use_phi=True, use_theta=True, use_psi=True, + nonlinearity1='relu'): + super(GridAttentionBlock2D_TORR, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer, + use_W=use_W, + use_phi=use_phi, + use_theta=use_theta, + use_psi=use_psi, + nonlinearity1=nonlinearity1) + + +class GridAttentionBlock3D_TORR(_GridAttentionBlockND_TORR): + def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', + sub_sample_factor=(1,1,1), bn_layer=True): + super(GridAttentionBlock3D_TORR, self).__init__(in_channels, + inter_channels=inter_channels, + gating_channels=gating_channels, + dimension=3, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer) + + +class MultiAttentionBlock(nn.Module): + def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): + super(MultiAttentionBlock, self).__init__() + self.gate_block_1 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.gate_block_2 = GridAttentionBlock2D(in_channels=in_size, gating_channels=gate_size, + inter_channels=inter_size, mode=nonlocal_mode, + sub_sample_factor=sub_sample_factor) + self.combine_gates = 
nn.Sequential(nn.Conv2d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(in_size), + nn.ReLU(inplace=True)) + + # initialise the blocks + for m in self.children(): + if m.__class__.__name__.find('GridAttentionBlock2D') != -1: continue + init_weights(m, init_type='kaiming') + + def forward(self, input, gating_signal): + gate_1, attention_1 = self.gate_block_1(input, gating_signal) + gate_2, attention_2 = self.gate_block_2(input, gating_signal) + + return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/__init__.py b/pymic/net/net2d/trans2d/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net/net2d/trans2d/swinunet.py b/pymic/net/net2d/trans2d/swinunet.py new file mode 100644 index 0000000..f35539a --- /dev/null +++ b/pymic/net/net2d/trans2d/swinunet.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/HuCaoFighting/Swin-Unet + +""" +from __future__ import print_function, division + +import copy +import numpy as np +import torch +import torch.nn as nn + +from pymic.net.net2d.trans2d.swinunet_sys import SwinTransformerSys + +class SwinUNet(nn.Module): + """ + Implementation of Swin-UNet. + + * Reference: Hu Cao, Yueyue Wang et al: + Swin-Unet: Unet-Like Pure Transformer for Medical Image Segmentation. + `ECCV 2022 Workshops. `_ + + Note that the input channel can only be 1 or 3, and the input image size should be 224x224. + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param img_size: (tuple) The input image size, should be [224, 224]. + :param class_num: (int) The class number for segmentation task. + """ + def __init__(self, params): + super(SwinUNet, self).__init__() + img_size = params['img_size'] + if(isinstance(img_size, tuple) or isinstance(img_size, list)): + img_size = img_size[0] + self.num_classes = params['class_num'] + self.swin_unet = SwinTransformerSys(img_size = img_size, num_classes=self.num_classes) + # self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, + # patch_size=config.MODEL.SWIN.PATCH_SIZE, + # in_chans=config.MODEL.SWIN.IN_CHANS, + # num_classes=self.num_classes, + # embed_dim=config.MODEL.SWIN.EMBED_DIM, + # depths=config.MODEL.SWIN.DEPTHS, + # num_heads=config.MODEL.SWIN.NUM_HEADS, + # window_size=config.MODEL.SWIN.WINDOW_SIZE, + # mlp_ratio=config.MODEL.SWIN.MLP_RATIO, + # qkv_bias=config.MODEL.SWIN.QKV_BIAS, + # qk_scale=config.MODEL.SWIN.QK_SCALE, + # drop_rate=config.MODEL.DROP_RATE, + # drop_path_rate=config.MODEL.DROP_PATH_RATE, + # ape=config.MODEL.SWIN.APE, + # patch_norm=config.MODEL.SWIN.PATCH_NORM, + # use_checkpoint=config.TRAIN.USE_CHECKPOINT) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + logits = self.swin_unet(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(logits.shape)[1:] + logits = torch.reshape(logits, new_shape) + logits = torch.transpose(logits, 1, 2) + + return logits + + def load_from(self, config): + pretrained_path = config.MODEL.PRETRAIN_CKPT + if pretrained_path is not None: + print("pretrained_path:{}".format(pretrained_path)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + pretrained_dict = torch.load(pretrained_path, map_location=device) + 
if "model" not in pretrained_dict: + print("---start load pretrained modle by splitting---") + pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} + for k in list(pretrained_dict.keys()): + if "output" in k: + print("delete key:{}".format(k)) + del pretrained_dict[k] + msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) + # print(msg) + return + pretrained_dict = pretrained_dict['model'] + print("---start load pretrained modle of swin encoder---") + + model_dict = self.swin_unet.state_dict() + full_dict = copy.deepcopy(pretrained_dict) + for k, v in pretrained_dict.items(): + if "layers." in k: + current_layer_num = 3-int(k[7:8]) + current_k = "layers_up." + str(current_layer_num) + k[8:] + full_dict.update({current_k:v}) + for k in list(full_dict.keys()): + if k in model_dict: + if full_dict[k].shape != model_dict[k].shape: + print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) + del full_dict[k] + + msg = self.swin_unet.load_state_dict(full_dict, strict=False) + # print(msg) + else: + print("none pretrain") + + +if __name__ == "__main__": + params = {'img_size': [224, 224], + 'class_num': 2} + net = SwinUNet(params) + net.double() + + x = np.random.rand(4, 3, 224, 224) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/swinunet_sys.py b/pymic/net/net2d/trans2d/swinunet_sys.py new file mode 100644 index 0000000..a6e3552 --- /dev/null +++ b/pymic/net/net2d/trans2d/swinunet_sys.py @@ -0,0 +1,749 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/HuCaoFighting/Swin-Unet + +""" +from __future__ import print_function, division + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from einops import rearrange +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. 
+ Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + 
# attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + +class PatchExpand(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity() + self.norm = norm_layer(dim // dim_scale) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4) + x = x.view(B,-1,C//4) + x= self.norm(x) + + return x + +class FinalPatchExpand_X4(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.dim_scale = dim_scale + self.expand = nn.Linear(dim, 16*dim, bias=False) + self.output_dim = dim + self.norm = norm_layer(self.output_dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2)) + x = x.view(B,-1,self.output_dim) + x= self.norm(x) + + return x + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + +class BasicLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if upsample is not None: + self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) + else: + self.upsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.upsample is not None: + x = self.upsample(x) + return x + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformerSys(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. 
Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, + embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, final_upsample="expand_first", **kwargs): + super().__init__() + + print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths, + depths_decoder,drop_path_rate,num_classes)) + + self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.num_features_up = int(embed_dim * 2) + self.mlp_ratio = mlp_ratio + self.final_upsample = final_upsample + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build encoder and bottleneck layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + # build decoder layers + self.layers_up = nn.ModuleList() + self.concat_back_dim = nn.ModuleList() + for i_layer in range(self.num_layers): + concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)), + int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity() + if i_layer ==0 : + layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + 
patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer) + else: + layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), + input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)), + patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), + depth=depths[(self.num_layers-1-i_layer)], + num_heads=num_heads[(self.num_layers-1-i_layer)], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])], + norm_layer=norm_layer, + upsample=PatchExpand if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers_up.append(layer_up) + self.concat_back_dim.append(concat_linear) + + self.norm = norm_layer(self.num_features) + self.norm_up= norm_layer(self.embed_dim) + + if self.final_upsample == "expand_first": + print("---final upsample expand_first---") + self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim) + self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + #Encoder and Bottleneck + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + x_downsample = [] + + for layer in self.layers: + x_downsample.append(x) + x = layer(x) + + x = self.norm(x) # B L C + + return x, x_downsample + + #Decoder and Skip connection + def forward_up_features(self, x, x_downsample): + for inx, layer_up in enumerate(self.layers_up): + if inx == 0: + x = layer_up(x) + else: + x = torch.cat([x,x_downsample[3-inx]],-1) + x = self.concat_back_dim[inx](x) + x = layer_up(x) + + x = self.norm_up(x) # B L C + + return x + + def up_x4(self, x): + H, W = self.patches_resolution + B, L, C = x.shape + assert L == H*W, "input features has wrong size" + + if self.final_upsample=="expand_first": + x = self.up(x) + x = x.view(B,4*H,4*W,-1) + x = x.permute(0,3,1,2) #B,C,H,W + x = self.output(x) + + return x + + def forward(self, x): + x, x_downsample = self.forward_features(x) + x = self.forward_up_features(x,x_downsample) + x = self.up_x4(x) + + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + flops += self.num_features * self.num_classes + return flops \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet.py b/pymic/net/net2d/trans2d/transunet.py new file mode 100644 index 0000000..9db5d2d --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet.py @@ -0,0 +1,491 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +from __future__
import print_function, division + +import copy +# import logging +import math +import torch +import torch.nn as nn +from torch.nn import Dropout, Softmax, Linear, Conv2d, LayerNorm +from torch.nn.modules.utils import _pair + +import numpy as np +from scipy import ndimage +from os.path import join as pjoin +import pymic.net.net2d.trans2d.transunet_cfg as configs +from pymic.net.net2d.trans2d.transunet_resnet import ResNetV2 + + +VIT_CONFIGS = { + 'ViT-B_16': configs.get_b16_config(), + 'ViT-B_32': configs.get_b32_config(), + 'ViT-L_16': configs.get_l16_config(), + 'ViT-L_32': configs.get_l32_config(), + 'ViT-H_14': configs.get_h14_config(), + 'R50-ViT-B_16': configs.get_r50_b16_config(), + 'R50-ViT-L_16': configs.get_r50_l16_config(), + 'testing': configs.get_testing(), +} + +ATTENTION_Q = "MultiHeadDotProductAttention_1/query" +ATTENTION_K = "MultiHeadDotProductAttention_1/key" +ATTENTION_V = "MultiHeadDotProductAttention_1/value" +ATTENTION_OUT = "MultiHeadDotProductAttention_1/out" +FC_0 = "MlpBlock_3/Dense_0" +FC_1 = "MlpBlock_3/Dense_1" +ATTENTION_NORM = "LayerNorm_0" +MLP_NORM = "LayerNorm_2" + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish} + +class Attention(nn.Module): + def __init__(self, config, vis): + super(Attention, self).__init__() + self.vis = vis + self.num_attention_heads = config.transformer["num_heads"] + self.attention_head_size = int(config.hidden_size / self.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = Linear(config.hidden_size, self.all_head_size) + self.key = Linear(config.hidden_size, self.all_head_size) + self.value = Linear(config.hidden_size, self.all_head_size) + + self.out = Linear(config.hidden_size, config.hidden_size) + self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"]) + self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"]) + + self.softmax = Softmax(dim=-1) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + attention_probs = self.softmax(attention_scores) + weights = attention_probs if self.vis else None + attention_probs = self.attn_dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + attention_output = self.out(context_layer) + attention_output = self.proj_dropout(attention_output) + return attention_output, weights + +class Mlp(nn.Module): + def __init__(self, config): + super(Mlp, self).__init__() + self.fc1 = 
Linear(config.hidden_size, config.transformer["mlp_dim"]) + self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size) + self.act_fn = ACT2FN["gelu"] + self.dropout = Dropout(config.transformer["dropout_rate"]) + + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + +class Embeddings(nn.Module): + """Construct the embeddings from patch, position embeddings. + """ + def __init__(self, config, img_size, in_channels=3): + super(Embeddings, self).__init__() + self.hybrid = None + self.config = config + img_size = _pair(img_size) + + if config.patches.get("grid") is not None: # ResNet + grid_size = config.patches["grid"] + patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1]) + patch_size_real = (patch_size[0] * 16, patch_size[1] * 16) + n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1]) + self.hybrid = True + else: + patch_size = _pair(config.patches["size"]) + n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) + self.hybrid = False + + if self.hybrid: + self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor) + in_channels = self.hybrid_model.width * 16 + self.patch_embeddings = Conv2d(in_channels=in_channels, + out_channels=config.hidden_size, + kernel_size=patch_size, + stride=patch_size) + self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size)) + + self.dropout = Dropout(config.transformer["dropout_rate"]) + + + def forward(self, x): + if self.hybrid: + x, features = self.hybrid_model(x) + else: + features = None + x = self.patch_embeddings(x) # (B, hidden. 
n_patches^(1/2), n_patches^(1/2)) + x = x.flatten(2) + x = x.transpose(-1, -2) # (B, n_patches, hidden) + + embeddings = x + self.position_embeddings + embeddings = self.dropout(embeddings) + return embeddings, features + +class Block(nn.Module): + def __init__(self, config, vis): + super(Block, self).__init__() + self.hidden_size = config.hidden_size + self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6) + self.ffn = Mlp(config) + self.attn = Attention(config, vis) + + def forward(self, x): + h = x + x = self.attention_norm(x) + x, weights = self.attn(x) + x = x + h + + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + return x, weights + + def load_from(self, weights, n_block): + ROOT = f"Transformer/encoderblock_{n_block}" + with torch.no_grad(): + query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t() + key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t() + value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t() + out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t() + + query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1) + key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1) + value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1) + out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1) + + self.attn.query.weight.copy_(query_weight) + self.attn.key.weight.copy_(key_weight) + self.attn.value.weight.copy_(value_weight) + self.attn.out.weight.copy_(out_weight) + self.attn.query.bias.copy_(query_bias) + self.attn.key.bias.copy_(key_bias) + self.attn.value.bias.copy_(value_bias) + self.attn.out.bias.copy_(out_bias) + + mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t() + mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t() + mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t() + mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t() + + self.ffn.fc1.weight.copy_(mlp_weight_0) + self.ffn.fc2.weight.copy_(mlp_weight_1) + self.ffn.fc1.bias.copy_(mlp_bias_0) + self.ffn.fc2.bias.copy_(mlp_bias_1) + + self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")])) + self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")])) + self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")])) + self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")])) + +class Encoder(nn.Module): + def __init__(self, config, vis): + super(Encoder, self).__init__() + self.vis = vis + self.layer = nn.ModuleList() + self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6) + for _ in range(config.transformer["num_layers"]): + layer = Block(config, vis) + self.layer.append(copy.deepcopy(layer)) + + def forward(self, hidden_states): + attn_weights = [] + for layer_block in self.layer: + hidden_states, weights = layer_block(hidden_states) + if self.vis: + attn_weights.append(weights) + encoded = self.encoder_norm(hidden_states) + return encoded, attn_weights + +class Transformer(nn.Module): + def __init__(self, config, img_size, vis): + super(Transformer, self).__init__() + self.embeddings = Embeddings(config, img_size=img_size) + self.encoder = Encoder(config, vis) + + def forward(self, input_ids): + embedding_output, features = 
self.embeddings(input_ids) + encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden) + return encoded, attn_weights, features + +class Conv2dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=not (use_batchnorm), + ) + relu = nn.ReLU(inplace=True) + + bn = nn.BatchNorm2d(out_channels) + + super(Conv2dReLU, self).__init__(conv, bn, relu) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + self.conv1 = Conv2dReLU( + in_channels + skip_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.conv2 = Conv2dReLU( + out_channels, + out_channels, + kernel_size=3, + padding=1, + use_batchnorm=use_batchnorm, + ) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, x, skip=None): + x = self.up(x) + if skip is not None: + x = torch.cat([x, skip], dim=1) + x = self.conv1(x) + x = self.conv2(x) + return x + + +class SegmentationHead(nn.Sequential): + + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() + super().__init__(conv2d, upsampling) + + +class DecoderCup(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + head_channels = 512 + self.conv_more = Conv2dReLU( + config.hidden_size, + head_channels, + kernel_size=3, + padding=1, + use_batchnorm=True, + ) + decoder_channels = config.decoder_channels + in_channels = [head_channels] + list(decoder_channels[:-1]) + out_channels = decoder_channels + + if self.config.n_skip != 0: + skip_channels = self.config.skip_channels + for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip + skip_channels[3-i]=0 + + else: + skip_channels=[0,0,0,0] + + blocks = [ + DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels) + ] + self.blocks = nn.ModuleList(blocks) + + def forward(self, hidden_states, features=None): + B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden) + h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch)) + x = hidden_states.permute(0, 2, 1) + x = x.contiguous().view(B, hidden, h, w) + x = self.conv_more(x) + for i, decoder_block in enumerate(self.blocks): + if features is not None: + skip = features[i] if (i < self.config.n_skip) else None + else: + skip = None + x = decoder_block(x, skip=skip) + return x + +class TransUNet(nn.Module): + """ + Implementatin of TransUNet. + + * Reference: Jieneng Chen, Yongyi Lu et al: + TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation. + `Arxiv 2021. `_ + + Note that the input channel can only be 1 or 3, and the input image size should be 256x256. + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param img_size: (tuple) The input image size, should be [256, 256]. + :param class_num: (int) The class number for segmentation task. + :param vit_name: (string) The name for vit backbone. It can be one of the following: 'ViT-B_16', + 'ViT-B_32','ViT-L_16', 'ViT-L_32', 'ViT-H_14'. 
'R50-ViT-B_16', 'R50-ViT-L_16'. + By default, it is 'R50-ViT-B_16'. + """ + def __init__(self, params): + super(TransUNet, self).__init__() + vit_name = params.get("vit_name", 'R50-ViT-B_16') + img_size = params['img_size'] + vis = params.get("vis", False) + self.config = VIT_CONFIGS[vit_name] + self.num_classes = params['class_num'] + self.zero_head = params.get("zero_head", False) + + self.classifier = self.config.classifier + self.transformer = Transformer(self.config, img_size, vis) + self.decoder = DecoderCup(self.config) + self.segmentation_head = SegmentationHead( + in_channels=self.config['decoder_channels'][-1], + out_channels=self.num_classes, + kernel_size=3, + ) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + if x.size()[1] == 1: + x = x.repeat(1,3,1,1) + elif(x.size()[1] !=3): + raise ValueError("The input channel number should be 1 or 3 for TransUNet") + x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden) + x = self.decoder(x, features) + logits = self.segmentation_head(x) + + if(len(x_shape) == 5): + new_shape = [N, D] + list(logits.shape)[1:] + logits = torch.reshape(logits, new_shape) + logits = torch.transpose(logits, 1, 2) + + return logits + + def load_from(self, weights): + with torch.no_grad(): + + res_weight = weights + self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True)) + self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"])) + + self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"])) + self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"])) + + posemb = np2th(weights["Transformer/posembed_input/pos_embedding"]) + + posemb_new = self.transformer.embeddings.position_embeddings + if posemb.size() == posemb_new.size(): + self.transformer.embeddings.position_embeddings.copy_(posemb) + elif posemb.size()[1]-1 == posemb_new.size()[1]: + posemb = posemb[:, 1:] + self.transformer.embeddings.position_embeddings.copy_(posemb) + else: + # logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size())) + ntok_new = posemb_new.size(1) + if self.classifier == "seg": + _, posemb_grid = posemb[:, :1], posemb[0, 1:] + gs_old = int(np.sqrt(len(posemb_grid))) + gs_new = int(np.sqrt(ntok_new)) + print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new)) + posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) + zoom = (gs_new / gs_old, gs_new / gs_old, 1) + posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np + posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) + posemb = posemb_grid + self.transformer.embeddings.position_embeddings.copy_(np2th(posemb)) + + # Encoder whole + for bname, block in self.transformer.encoder.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, n_block=uname) + + if self.transformer.embeddings.hybrid: + self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True)) + gn_weight = np2th(res_weight["gn_root/scale"]).view(-1) + gn_bias = np2th(res_weight["gn_root/bias"]).view(-1) + self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight) + self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias) + + for bname, block in 
self.transformer.embeddings.hybrid_model.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(res_weight, n_block=bname, n_unit=uname) + +if __name__ == "__main__": + params = {'img_size': [256, 256], + 'class_num': 2} + net = TransUNet(params) + net.double() + + for c in [1,3]: + x = np.random.rand(4, c, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet_cfg.py b/pymic/net/net2d/trans2d/transunet_cfg.py new file mode 100644 index 0000000..aab62d4 --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet_cfg.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +import ml_collections + +def get_b16_config(): + """Returns the ViT-B/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 768 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 3072 + config.transformer.num_heads = 12 + config.transformer.num_layers = 12 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + + config.classifier = 'seg' + config.representation_size = None + config.resnet_pretrained_path = None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_16.npz' + config.patch_size = 16 + + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_testing(): + """Returns a minimal configuration for testing.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 1 + config.transformer.num_heads = 1 + config.transformer.num_layers = 1 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + return config + +def get_r50_b16_config(): + """Returns the Resnet50 + ViT-B/16 configuration.""" + config = get_b16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.n_skip = 3 + config.activation = 'softmax' + + return config + + +def get_b32_config(): + """Returns the ViT-B/32 configuration.""" + config = get_b16_config() + config.patches.size = (32, 32) + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-B_32.npz' + return config + + +def get_l16_config(): + """Returns the ViT-L/16 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (16, 16)}) + config.hidden_size = 1024 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 4096 + config.transformer.num_heads = 16 + config.transformer.num_layers = 24 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.representation_size = None + + # custom + config.classifier = 'seg' + config.resnet_pretrained_path = 
None + config.pretrained_path = '../model/vit_checkpoint/imagenet21k/ViT-L_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_r50_l16_config(): + """Returns the Resnet50 + ViT-L/16 configuration (customized).""" + config = get_l16_config() + config.patches.grid = (16, 16) + config.resnet = ml_collections.ConfigDict() + config.resnet.num_layers = (3, 4, 9) + config.resnet.width_factor = 1 + + config.classifier = 'seg' + config.resnet_pretrained_path = '../model/vit_checkpoint/imagenet21k/R50+ViT-B_16.npz' + config.decoder_channels = (256, 128, 64, 16) + config.skip_channels = [512, 256, 64, 16] + config.n_classes = 2 + config.activation = 'softmax' + return config + + +def get_l32_config(): + """Returns the ViT-L/32 configuration.""" + config = get_l16_config() + config.patches.size = (32, 32) + return config + + +def get_h14_config(): + """Returns the ViT-H/14 configuration.""" + config = ml_collections.ConfigDict() + config.patches = ml_collections.ConfigDict({'size': (14, 14)}) + config.hidden_size = 1280 + config.transformer = ml_collections.ConfigDict() + config.transformer.mlp_dim = 5120 + config.transformer.num_heads = 16 + config.transformer.num_layers = 32 + config.transformer.attention_dropout_rate = 0.0 + config.transformer.dropout_rate = 0.1 + config.classifier = 'token' + config.representation_size = None + + return config \ No newline at end of file diff --git a/pymic/net/net2d/trans2d/transunet_resnet.py b/pymic/net/net2d/trans2d/transunet_resnet.py new file mode 100644 index 0000000..144a268 --- /dev/null +++ b/pymic/net/net2d/trans2d/transunet_resnet.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +""" +code adapted from: https://github.com/Beckschen/TransUNet +""" +from __future__ import print_function, division + +from os.path import join as pjoin +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def np2th(weights, conv=False): + """Possibly convert HWIO to OIHW.""" + if conv: + weights = weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(weights) + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-5) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv1 = conv1x1(cin, cmid, bias=False) + self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6) + self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cout, eps=1e-6) + self.conv3 = conv1x1(cmid, cout, bias=False) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper.
+ self.downsample = conv1x1(cin, cout, stride, bias=False) + self.gn_proj = nn.GroupNorm(cout, cout) + + def forward(self, x): + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(x) + residual = self.gn_proj(residual) + + # Unit's branch + y = self.relu(self.gn1(self.conv1(x))) + y = self.relu(self.gn2(self.conv2(y))) + y = self.gn3(self.conv3(y)) + + y = self.relu(residual + y) + return y + + def load_from(self, weights, n_block, n_unit): + conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True) + conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True) + conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True) + + gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")]) + gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")]) + + gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")]) + gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")]) + + gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")]) + gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")]) + + self.conv1.weight.copy_(conv1_weight) + self.conv2.weight.copy_(conv2_weight) + self.conv3.weight.copy_(conv3_weight) + + self.gn1.weight.copy_(gn1_weight.view(-1)) + self.gn1.bias.copy_(gn1_bias.view(-1)) + + self.gn2.weight.copy_(gn2_weight.view(-1)) + self.gn2.bias.copy_(gn2_bias.view(-1)) + + self.gn3.weight.copy_(gn3_weight.view(-1)) + self.gn3.bias.copy_(gn3_bias.view(-1)) + + if hasattr(self, 'downsample'): + proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True) + proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")]) + proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")]) + + self.downsample.weight.copy_(proj_conv_weight) + self.gn_proj.weight.copy_(proj_gn_weight.view(-1)) + self.gn_proj.bias.copy_(proj_gn_bias.view(-1)) + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor): + super().__init__() + width = int(64 * width_factor) + self.width = width + + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)), + ('gn', nn.GroupNorm(32, width, eps=1e-6)), + ('relu', nn.ReLU(inplace=True)), + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)) + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] + + [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)], + ))), + ])) + + def forward(self, x): + features = [] + b, c, in_size, _ = x.size() + x = self.root(x) + features.append(x) + x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) + for i in range(len(self.body)-1): + x = self.body[i](x) + right_size = int(in_size / 4 / (i+1)) + if x.size()[2] != right_size: + pad = right_size - 
x.size()[2] + assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size) + feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device) + feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:] + else: + feat = x + features.append(feat) + x = self.body[-1](x) + return x, features[::-1] \ No newline at end of file diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 9acc0ad..be69f0d 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division +import logging import torch import torch.nn as nn import numpy as np -from torch.nn.functional import interpolate class ConvBlock(nn.Module): """ @@ -56,22 +56,32 @@ class UpBlock(nn.Module): :param in_channels2: (int) Channel number of low-level features. :param out_channels: (int) Output channel number. :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling (by default). - If False, deconvolution is used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (`Nearest`), 2 (`Bilinear`), 3 (`Bicubic`). The default value + is 2 (`Bilinear`). """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - bilinear=True): + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode = 2): super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + if(isinstance(up_mode, int)): + up_mode_values = ["transconv", "nearest", "bilinear", "bicubic"] + if(up_mode > 3): + raise ValueError("The upsample mode should be 0-3, but {0:} is given.".format(up_mode)) + self.up_mode = up_mode_values[up_mode] else: + self.up_mode = up_mode.lower() + + if (self.up_mode == "transconv"): self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + else: + self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) + if(self.up_mode == "nearest"): + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode) + else: + self.up = nn.Upsample(scale_factor=2, mode=self.up_mode, align_corners=True) self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p) def forward(self, x1, x2): - if self.bilinear: + if self.up_mode != "transconv": x1 = self.conv1x1(x1) x1 = self.up(x1) x = torch.cat([x2, x1], dim=1) @@ -129,8 +139,10 @@ class Decoder(nn.Module): :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). + The default value is 2 (or `Bilinear`). + :param multiscale_pred: (bool) Get multi-scale prediction. 
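    Example (an illustrative sketch; the channel numbers and the 64x64 feature size are assumptions rather than recommended settings):

        import torch
        params = {'in_chns': 1, 'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.0, 0.0, 0.2, 0.3, 0.4],
                  'class_num': 2, 'up_mode': 2, 'multiscale_pred': False}
        decoder = Decoder(params)
        # one feature map per resolution level, from full size down to 1/16
        feats = [torch.randn(2, c, 64 // 2 ** i, 64 // 2 ** i)
                 for i, c in enumerate(params['feature_chns'])]
        out = decoder(feats)                             # (2, 2, 64, 64)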
""" def __init__(self, params): super(Decoder, self).__init__() @@ -139,17 +151,27 @@ def __init__(self, params): self.ft_chns = self.params['feature_chns'] self.dropout = self.params['dropout'] self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] + self.up_mode = self.params.get('up_mode', 2) + self.mul_pred = self.params.get('multiscale_pred', False) assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) if(len(self.ft_chns) == 5): - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) + if(self.mul_pred and (self.training or self.mul_infer)): + self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) + self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) + self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + self.stage = 'train' + + def set_stage(self, stage): + self.stage = stage + def forward(self, x): if(len(self.ft_chns) == 5): assert(len(x) == 5) @@ -163,6 +185,11 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) + if(self.mul_pred and self.stage == 'train'): + output1 = self.out_conv1(x_d1) + output2 = self.out_conv2(x_d2) + output3 = self.out_conv3(x_d3) + output = [output, output1, output2, output3] return output class UNet2D(nn.Module): @@ -180,43 +207,43 @@ class UNet2D(nn.Module): following fields: :param in_chns: (int) Input channel number. + :param class_num: (int) The class number for segmentation task. + + Optional parameters: + :param feature_chns: (list) Feature channel for each resolution level. The length should be 4 or 5, such as [16, 32, 64, 128, 256]. :param dropout: (list) The dropout ratio for each resolution level. The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + :param up_mode: (string or int) The mode for upsampling. The allowed values are: + 0 (or `TransConv`), 1 (or `Nearest`), 2 (or `Bilinear`), 3 (or `Bicubic`). + The default value is 2 (or `Bilinear`). :param multiscale_pred: (bool) Get multiscale prediction. 
""" def __init__(self, params): super(UNet2D, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - self.mul_pred = self.params['multiscale_pred'] + params = self.get_default_parameters(params) + for p in params: + print(p, params[p]) + self.encoder = Encoder(params) + self.decoder = Decoder(params) + + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params - assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - if(len(self.ft_chns) == 5): - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) - if(self.mul_pred): - self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1) - self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1) - self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1) + def set_stage(self, stage): + self.stage = stage + self.decoder.set_stage(stage) def forward(self, x): x_shape = list(x.shape) @@ -226,51 +253,15 @@ def forward(self, x): x = torch.transpose(x, 1, 2) x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3 = self.up1(x4, x3) - else: - x_d3 = x3 - x_d2 = self.up2(x_d3, x2) - x_d1 = self.up3(x_d2, x1) - x_d0 = self.up4(x_d1, x0) - output = self.out_conv(x_d0) - if(self.mul_pred): - output1 = self.out_conv1(x_d1) - output2 = self.out_conv2(x_d2) - output3 = self.out_conv3(x_d3) - output = [output, output1, output2, output3] - - if(len(x_shape) == 5): + f = self.encoder(x) + output = self.decoder(f) + if(len(x_shape) == 5): + if(isinstance(output, (list,tuple))): for i in range(len(output)): new_shape = [N, D] + list(output[i].shape)[1:] output[i] = torch.transpose(torch.reshape(output[i], new_shape), 1, 2) - elif(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'bilinear': True, - 'multiscale_pred': False} - Net = UNet2D(params) - Net = Net.double() - - x = np.random.rand(4, 4, 10, 96, 96) - xt = torch.from_numpy(x) 
- xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) + else: + new_shape = [N, D] + list(output.shape)[1:] + output = torch.transpose(torch.reshape(output, new_shape), 1, 2) + + return output \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_attention.py b/pymic/net/net2d/unet2d_attention.py index 6afdfdc..36faec8 100644 --- a/pymic/net/net2d/unet2d_attention.py +++ b/pymic/net/net2d/unet2d_attention.py @@ -4,14 +4,7 @@ import torch import torch.nn as nn from pymic.net.net2d.unet2d import * -""" -A Reimplementation of the attention U-Net paper: - Ozan Oktay, Jo Schlemper et al.: - Attentin U-Net: Looking Where to Look for the Pancreas. MIDL, 2018. -Note that there are some modifications from the original paper, such as -the use of batch normalization, dropout, and leaky relu here. -""" class AttentionGateBlock(nn.Module): def __init__(self, chns_l, chns_h): """ @@ -80,6 +73,14 @@ def forward(self, x1, x2): return self.conv(x) class AttentionUNet2D(UNet2D): + """ + A Reimplementation of the attention U-Net paper: + Ozan Oktay, Jo Schlemper et al.: + Attentin U-Net: Looking Where to Look for the Pancreas. MIDL, 2018. + + Note that there are some modifications from the original paper, such as + the use of batch normalization, dropout, and leaky relu here. + """ def __init__(self, params): super(AttentionUNet2D, self).__init__(params) self.up1 = UpBlockWithAttention(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0) diff --git a/pymic/net/net2d/unet2d_canet.py b/pymic/net/net2d/unet2d_canet.py new file mode 100644 index 0000000..defcb60 --- /dev/null +++ b/pymic/net/net2d/unet2d_canet.py @@ -0,0 +1,690 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from pymic.net.net2d.canet_module import * + + +def conv3x3(in_planes, out_planes, stride=1, bias=False, group=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,padding=1, groups=group, bias=bias) + + +class SE_Conv_Block(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, drop_out=False): + super(SE_Conv_Block, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes * 2) + self.bn2 = nn.BatchNorm2d(planes * 2) + self.conv3 = conv3x3(planes * 2, planes) + self.bn3 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.dropout = drop_out + + self.fc1 = nn.Linear(in_features=planes * 2, out_features=round(planes / 2)) + self.fc2 = nn.Linear(in_features=round(planes / 2), out_features=planes * 2) + self.sigmoid = nn.Sigmoid() + + self.downchannel = None + if inplanes != planes: + self.downchannel = nn.Sequential(nn.Conv2d(inplanes, planes * 2, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * 2),) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downchannel is not None: + residual = self.downchannel(x) + + original_out = out + out1 = out + # For global average pool + out = F.adaptive_avg_pool2d(out, (1,1)) + out = out.view(out.size(0), -1) + out = self.fc1(out) + out = self.relu(out) + out = self.fc2(out) + out = self.sigmoid(out) + out = 
out.view(out.size(0), out.size(1), 1, 1) + avg_att = out + out = out * original_out + # For global maximum pool + out1 = F.adaptive_max_pool2d(out1, (1,1)) + out1 = out1.view(out1.size(0), -1) + out1 = self.fc1(out1) + out1 = self.relu(out1) + out1 = self.fc2(out1) + out1 = self.sigmoid(out1) + out1 = out1.view(out1.size(0), out1.size(1), 1, 1) + max_att = out1 + out1 = out1 * original_out + + att_weight = avg_att + max_att + out += out1 + out += residual + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out, att_weight + +# # CBAM Convolutional block attention module +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, + relu=True, bn=True, bias=False): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, + dilation=dilation, groups=groups, bias=bias) + self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None + self.relu = nn.ReLU() if relu else None + + def forward(self, x): + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.relu is not None: + x = self.relu(x) + return x + + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type == 'avg': + avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(avg_pool) + elif pool_type == 'max': + max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(max_pool) + elif pool_type == 'lp': + lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp(lp_pool) + elif pool_type == 'lse': + # LSE pool only + lse_pool = logsumexp_2d(x) + channel_att_raw = self.mlp(lse_pool) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + # scalecoe = F.sigmoid(channel_att_sum) + # print("channel att_sum", channel_att_sum.shape) + # channel_att_sum = channel_att_sum.reshape(channel_att_sum.shape[0], 4, 4) + # avg_weight = torch.mean(channel_att_sum, dim=2).unsqueeze(2) + # avg_weight = avg_weight.expand(channel_att_sum.shape[0], 4, 4).reshape(channel_att_sum.shape[0], 16) + # scale = F.sigmoid(avg_weight).unsqueeze(2).unsqueeze(3).expand_as(x) + scale = F.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x) + return x * scale, scale + + +def logsumexp_2d(tensor): + tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) + s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) + outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() + return outputs + + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1) + + +class 
SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) + + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = F.sigmoid(x_out) # broadcasting + return x * scale, scale + +class SpatialAtten(nn.Module): + def __init__(self, in_size, out_size, kernel_size=3, stride=1): + super(SpatialAtten, self).__init__() + self.conv1 = BasicConv(in_size, out_size, kernel_size, stride=stride, + padding=(kernel_size-1) // 2, relu=True) + self.conv2 = BasicConv(out_size, in_size, kernel_size=1, stride=stride, + padding=0, relu=True, bn=False) + + def forward(self, x): + residual = x + x_out = self.conv1(x) + x_out = self.conv2(x_out) + spatial_att = F.sigmoid(x_out) + # .unsqueeze(4).permute(0, 1, 4, 2, 3) + # spatial_att = spatial_att.expand(spatial_att.shape[0], 4, 4, spatial_att.shape[3], spatial_att.shape[4]).reshape( + # spatial_att.shape[0], 16, spatial_att.shape[3], spatial_att.shape[4]) + x_out = residual * spatial_att + + x_out += residual + + return x_out, spatial_att + +class Scale_atten_block(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(Scale_atten_block, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial = no_spatial + if not no_spatial: + self.SpatialGate = SpatialAtten(gate_channels, gate_channels //reduction_ratio) + + def forward(self, x): + x_out, ca_atten = self.ChannelGate(x) + if not self.no_spatial: + x_out, sa_atten = self.SpatialGate(x_out) + + return x_out, ca_atten, sa_atten + + +class scale_atten_convblock(nn.Module): + def __init__(self, in_size, out_size, stride=1, downsample=None, use_cbam=True, no_spatial=False, drop_out=False): + super(scale_atten_convblock, self).__init__() + self.downsample = downsample + self.stride = stride + self.no_spatial = no_spatial + self.dropout = drop_out + + self.relu = nn.ReLU(inplace=True) + self.conv3 = conv3x3(in_size, out_size) + self.bn3 = nn.BatchNorm2d(out_size) + + if use_cbam: + self.cbam = Scale_atten_block(in_size, reduction_ratio=4, no_spatial=self.no_spatial) # out_size + else: + self.cbam = None + + def forward(self, x): + residual = x + + if self.downsample is not None: + residual = self.downsample(x) + + if not self.cbam is None: + out, scale_c_atten, scale_s_atten = self.cbam(x) + + out += residual + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + out = self.relu(out) + + if self.dropout: + out = nn.Dropout2d(0.5)(out) + + return out + +class _NonLocalBlockND(nn.Module): + def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', + sub_sample_factor=4, bn_layer=True): + super(_NonLocalBlockND, self).__init__() + + assert dimension in [1, 2, 3] + assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down'] + + # print('Dimension: %d, mode: %s' % (dimension, mode)) + + self.mode = mode + self.dimension = dimension + self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor] + + self.in_channels = in_channels + self.inter_channels = inter_channels + + if self.inter_channels is None: + self.inter_channels = in_channels // 2 + if self.inter_channels == 0: + self.inter_channels = 1 + + if dimension == 3: + conv_nd = 
nn.Conv3d + max_pool = nn.MaxPool3d + bn = nn.BatchNorm3d + elif dimension == 2: + conv_nd = nn.Conv2d + max_pool = nn.MaxPool2d + bn = nn.BatchNorm2d + else: + conv_nd = nn.Conv1d + max_pool = nn.MaxPool1d + bn = nn.BatchNorm1d + + self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + if bn_layer: + self.W = nn.Sequential( + conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0), + bn(self.in_channels) + ) + nn.init.constant(self.W[1].weight, 0) + nn.init.constant(self.W[1].bias, 0) + else: + self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0) + nn.init.constant(self.W.weight, 0) + nn.init.constant(self.W.bias, 0) + + self.theta = None + self.phi = None + + if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']: + self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, + kernel_size=1, stride=1, padding=0) + + if mode in ['concatenation']: + self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False) + self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False) + elif mode in ['concat_proper', 'concat_proper_down']: + self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, + padding=0, bias=True) + + if mode == 'embedded_gaussian': + self.operation_function = self._embedded_gaussian + elif mode == 'dot_product': + self.operation_function = self._dot_product + elif mode == 'gaussian': + self.operation_function = self._gaussian + elif mode == 'concatenation': + self.operation_function = self._concatenation + elif mode == 'concat_proper': + self.operation_function = self._concatenation_proper + elif mode == 'concat_proper_down': + self.operation_function = self._concatenation_proper_down + else: + raise NotImplementedError('Unknown operation function.') + + if any(ss > 1 for ss in self.sub_sample_factor): + self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor)) + if self.phi is None: + self.phi = max_pool(kernel_size=sub_sample_factor) + else: + self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor)) + if mode == 'concat_proper_down': + self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor)) + + # Initialise weights + for m in self.children(): + init_weights(m, init_type='kaiming') + + def forward(self, x): + ''' + :param x: (b, c, t, h, w) + :return: + ''' + + output = self.operation_function(x) + return output + + def _embedded_gaussian(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + f_div_C = F.softmax(f, dim=-1) + + # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w) + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = 
y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _gaussian(self, x): + batch_size = x.size(0) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = x.view(batch_size, self.in_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + + if self.sub_sample_factor > 1: + phi_x = self.phi(x).view(batch_size, self.in_channels, -1) + else: + phi_x = x.view(batch_size, self.in_channels, -1) + + f = torch.matmul(theta_x, phi_x) + f_div_C = F.softmax(f, dim=-1) + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _dot_product(self, x): + batch_size = x.size(0) + + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + g_x = g_x.permute(0, 2, 1) + + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + theta_x = theta_x.permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + f = torch.matmul(theta_x, phi_x) + N = f.size(-1) + f_div_C = f / N + + y = torch.matmul(f_div_C, g_x) + y = y.permute(0, 2, 1).contiguous() + y = y.view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) + + # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw) + # phi => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw) + # f=> RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw) + f = self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \ + self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1)) + f = F.relu(f, inplace=True) + + # Normalise the relations + N = f.size(-1) + f_div_c = f / N + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation_proper(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) + theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) + # phi => (b, 0.5c, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) + # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) + f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ + phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) + f = F.relu(f, inplace=True) + + # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) + f = torch.squeeze(self.psi(f), dim=1) + + # Normalise the relations + f_div_c = F.softmax(f, dim=1) + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 
0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) + W_y = self.W(y) + z = W_y + x + + return z + + def _concatenation_proper_down(self, x): + batch_size = x.size(0) + + # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) + g_x = self.g(x).view(batch_size, self.inter_channels, -1) + + # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) + # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) + theta_x = self.theta(x) + downsampled_size = theta_x.size() + theta_x = theta_x.view(batch_size, self.inter_channels, -1) + phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) + + # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) + # phi => (b, 0.5, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) + # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) + f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ + phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) + f = F.relu(f, inplace=True) + + # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) + f = torch.squeeze(self.psi(f), dim=1) + + # Normalise the relations + f_div_c = F.softmax(f, dim=1) + + # g(x_j) * f(x_j, x_i) + # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) + y = torch.matmul(g_x, f_div_c) + y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:]) + + # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3) + y = F.upsample(y, size=x.size()[2:], mode='trilinear') + + # attention block output + W_y = self.W(y) + z = W_y + x + + return z + + +class NONLocalBlock2D(_NonLocalBlockND): + def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True): + super(NONLocalBlock2D, self).__init__(in_channels, + inter_channels=inter_channels, + dimension=2, mode=mode, + sub_sample_factor=sub_sample_factor, + bn_layer=bn_layer) + + +class CANet(nn.Module): + """ + Implementation of CANet (Comprehensive Attention Network) for image segmentation. + + * Reference: R. Gu et al. `CA-Net: Comprehensive Attention Convolutional Neural Networks + for Explainable Medical Image Segmentation `_. + IEEE Transactions on Medical Imaging, 40(2),2021:699-711. + + Parameters are given in the `params` dictionary, and should include the + following fields: + + :param in_chns: (int) Input channel number. + :param feature_chns: (list) Feature channel for each resolution level. + The length should be 5, such as [16, 32, 64, 128, 256]. + :param dropout: (list) The dropout ratio for each resolution level. + The length should be the same as that of `feature_chns`. + :param class_num: (int) The class number for segmentation task. + :param bilinear: (bool) Using bilinear for up-sampling or not. + If False, deconvolution will be used for up-sampling. 
+ """ + def __init__(self, params): #args, in_ch=3, n_classes=2, feature_scale=4, is_deconv=True, is_batchnorm=True, + # nonlocal_mode='concatenation', attention_dsample=(1, 1)): + super(CANet, self).__init__() + self.in_channels = params['in_chns'] + self.num_classes = params['class_num'] + self.is_deconv = params.get('is_deconv', True) + self.is_batchnorm = params.get('is_batchnorm', True) + self.feature_scale = params.get('feature_scale', 4) + nonlocal_mode = 'concatenation' + attention_dsample = (1, 1) + + filters = [64, 128, 256, 512, 1024] + filters = [int(x / self.feature_scale) for x in filters] + + # downsampling + self.conv1 = conv_block(self.in_channels, filters[0]) + self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv2 = conv_block(filters[0], filters[1]) + self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv3 = conv_block(filters[1], filters[2]) + self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.conv4 = conv_block(filters[2], filters[3], drop_out=True) + self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 2)) + + self.center = conv_block(filters[3], filters[4], drop_out=True) + + # attention blocks + # self.attentionblock1 = GridAttentionBlock2D(in_channels=filters[0], gating_channels=filters[1], + # inter_channels=filters[0]) + self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], + nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample) + self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], + nonlocal_mode=nonlocal_mode, sub_sample_factor=attention_dsample) + self.nonlocal4_2 = NONLocalBlock2D(in_channels=filters[4], inter_channels=filters[4] // 4) + + # upsampling + self.up_concat4 = UpCat(filters[4], filters[3], self.is_deconv) + self.up_concat3 = UpCat(filters[3], filters[2], self.is_deconv) + self.up_concat2 = UpCat(filters[2], filters[1], self.is_deconv) + self.up_concat1 = UpCat(filters[1], filters[0], self.is_deconv) + self.up4 = SE_Conv_Block(filters[4], filters[3], drop_out=True) + self.up3 = SE_Conv_Block(filters[3], filters[2]) + self.up2 = SE_Conv_Block(filters[2], filters[1]) + self.up1 = SE_Conv_Block(filters[1], filters[0]) + + # For deep supervision, project the multi-scale feature maps to the same number of channels + self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=filters[0]//2, kernel_size=1) + self.dsv2 = nn.Conv2d(in_channels=filters[1], out_channels=filters[0]//2, kernel_size=1) + self.dsv3 = nn.Conv2d(in_channels=filters[2], out_channels=filters[0]//2, kernel_size=1) + self.dsv4 = nn.Conv2d(in_channels=filters[3], out_channels=filters[0]//2, kernel_size=1) + + self.scale_att = scale_atten_convblock(in_size=filters[0]//2 * 4, out_size=filters[0]) + self.final = nn.Conv2d(filters[0], self.num_classes, kernel_size=1) + + def forward(self, inputs): + # Feature Extraction + conv1 = self.conv1(inputs) + maxpool1 = self.maxpool1(conv1) + + conv2 = self.conv2(maxpool1) + maxpool2 = self.maxpool2(conv2) + + conv3 = self.conv3(maxpool2) + maxpool3 = self.maxpool3(conv3) + + conv4 = self.conv4(maxpool3) + maxpool4 = self.maxpool4(conv4) + + # Gating Signal Generation + center = self.center(maxpool4) + + # Attention Mechanism + # Upscaling Part (Decoder) + up4 = self.up_concat4(conv4, center) + g_conv4 = self.nonlocal4_2(up4) + + up4, att_weight4 = self.up4(g_conv4) + g_conv3, att3 = self.attentionblock3(conv3, up4) + + # atten3_map = att3.cpu().detach().numpy().astype(np.float) + # atten3_map = 
ndimage.interpolation.zoom(atten3_map, [1.0, 1.0, 224 / atten3_map.shape[2], + # 300 / atten3_map.shape[3]], order=0) + + up3 = self.up_concat3(g_conv3, up4) + up3, att_weight3 = self.up3(up3) + g_conv2, att2 = self.attentionblock2(conv2, up3) + + up2 = self.up_concat2(g_conv2, up3) + up2, att_weight2 = self.up2(up2) + + up1 = self.up_concat1(conv1, up2) + up1, att_weight1 = self.up1(up1) + + # Deep Supervision + dsv1 = self.dsv1(up1) + dsv2 = F.interpolate(self.dsv2(up2), dsv1.shape[2:], mode = 'bilinear') + dsv3 = F.interpolate(self.dsv3(up3), dsv1.shape[2:], mode = 'bilinear') + dsv4 = F.interpolate(self.dsv4(up4), dsv1.shape[2:], mode = 'bilinear') + + dsv_cat = torch.cat([dsv1, dsv2, dsv3, dsv4], dim=1) + out = self.scale_att(dsv_cat) + + out = self.final(out) + + return out + +if __name__ == "__main__": + params = {'in_chns':3, + 'class_num':2} + Net = CANet(params) + Net = Net.double() + + x = np.random.rand(4, 3, 224, 224) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py index 828bdfe..19a0788 100644 --- a/pymic/net/net2d/unet2d_dual_branch.py +++ b/pymic/net/net2d/unet2d_dual_branch.py @@ -25,11 +25,26 @@ class UNet2D_DualBranch(nn.Module): """ def __init__(self, params): super(UNet2D_DualBranch, self).__init__() - self.output_mode = params.get("output_mode", "average") + params = self.get_default_parameters(params) + self.output_mode = params["output_mode"] self.encoder = Encoder(params) self.decoder1 = Decoder(params) self.decoder2 = Decoder(params) + def get_default_parameters(self, params): + default_param = { + 'feature_chns': [32, 64, 128, 256, 512], + 'dropout': [0.0, 0.0, 0.2, 0.3, 0.4], + 'up_mode': 2, + 'multiscale_pred': False, + 'output_mode': "average" + } + for key in default_param: + params[key] = params.get(key, default_param[key]) + for key in params: + logging.info("{0:} = {1:}".format(key, params[key])) + return params + def forward(self, x): x_shape = list(x.shape) if(len(x_shape) == 5): diff --git a/pymic/net/net2d/unet2d_mcnet.py b/pymic/net/net2d/unet2d_mcnet.py new file mode 100644 index 0000000..be5b16b --- /dev/null +++ b/pymic/net/net2d/unet2d_mcnet.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch.nn as nn +from pymic.net.net2d.unet2d import * + +class MCNet2D(nn.Module): + """ + A tri-branch network using UNet2D as backbone. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `Medical Image Analysis 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UNet2D` for details. 
+ """ + def __init__(self, params): + super(MCNet2D, self).__init__() + in_chns = params['in_chns'] + class_num = params['class_num'] + params1 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 0, + 'multiscale_pred': False } + params2 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 1, + 'multiscale_pred': False} + params3 = {'in_chns': in_chns, + 'feature_chns': [16, 32, 64, 128, 256], + 'dropout': [0.05, 0.1, 0.2, 0.3, 0.5], + 'class_num': class_num, + 'up_mode': 2, + 'multiscale_pred': False} + self.encoder = Encoder(params1) + self.decoder1 = Decoder(params1) + self.decoder2 = Decoder(params2) + self.decoder3 = Decoder(params3) + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + feature = self.encoder(x) + output1 = self.decoder1(feature) + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.transpose(torch.reshape(output1, new_shape), 1, 2) + if(not self.training): + return output1 + output2 = self.decoder2(feature) + output3 = self.decoder3(feature) + if(len(x_shape) == 5): + output2 = torch.transpose(torch.reshape(output2, new_shape), 1, 2) + output3 = torch.transpose(torch.reshape(output3, new_shape), 1, 2) + return output1, output2, output3 diff --git a/pymic/net/net2d/unet2d_scse.py b/pymic/net/net2d/unet2d_scse.py index 125843e..54a5d2f 100644 --- a/pymic/net/net2d/unet2d_scse.py +++ b/pymic/net/net2d/unet2d_scse.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn import numpy as np +from pymic.net.net2d.unet2d import UpBlock, Encoder, Decoder, UNet2D from pymic.net.net2d.scse2d import * class ConvScSEBlock(nn.Module): @@ -50,116 +51,64 @@ def __init__(self, in_channels, out_channels, dropout_p): def forward(self, x): return self.maxpool_conv(x) -class UpBlock(nn.Module): +class UpBlockScSE(UpBlock): """Up-sampling followed by `ConvScSEBlock` in U-Net structure. - :param in_channels1: (int) Input channel number for low-resolution feature map. - :param in_channels2: (int) Input channel number for high-resolution feature map. - :param out_channels: (int) Output channel number. - :param dropout_p: (int) Dropout probability. - :param bilinear: (bool) Use bilinear for up-sampling or not. + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.UpBlock` for details. """ - def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, - bilinear=True): - super(UpBlock, self).__init__() - self.bilinear = bilinear - if bilinear: - self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size = 1) - self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - else: - self.up = nn.ConvTranspose2d(in_channels1, in_channels2, kernel_size=2, stride=2) + def __init__(self, in_channels1, in_channels2, out_channels, dropout_p, up_mode = 2): + super(UpBlockScSE, self).__init__(in_channels1, in_channels2, out_channels, dropout_p, up_mode) self.conv = ConvScSEBlock(in_channels2 * 2, out_channels, dropout_p) - def forward(self, x1, x2): - if self.bilinear: - x1 = self.conv1x1(x1) - x1 = self.up(x1) - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - -class UNet2D_ScSE(nn.Module): +class EncoderScSE(Encoder): """ - Combining 2D U-Net with SCSE module. 
+ Encoder of 2D UNet with ScSE. - * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: - Recalibrating Fully Convolutional Networks With Spatial and Channel - "Squeeze and Excitation" Blocks. - `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.Encoder` for details. """ def __init__(self, params): - super(UNet2D_ScSE, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) + super(EncoderScSE, self).__init__(params) self.in_conv= ConvScSEBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) +class DecoderScSE(Decoder): + """ + Decoder of 2D UNet with ScSE. + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.Decoder` for details. 
+ """ + def __init__(self, params): + super(DecoderScSE, self).__init__(params) - x = self.up1(x4, x3) - x = self.up2(x, x2) - x = self.up3(x, x1) - x = self.up4(x, x0) - output = self.out_conv(x) - - if(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - return output - -if __name__ == "__main__": - params = {'in_chns':4, - 'feature_chns':[2, 8, 32, 48, 64], - 'dropout': [0, 0, 0.3, 0.4, 0.5], - 'class_num': 2, - 'bilinear': True} - Net = UNet2D_ScSE(params) - Net = Net.double() - - x = np.random.rand(4, 4, 10, 96, 96) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = Net(xt) - print(len(y.size())) - y = y.detach().numpy() - print(y.shape) \ No newline at end of file + + if(len(self.ft_chns) == 5): + self.up1 = UpBlockScSE(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.up_mode) + self.up2 = UpBlockScSE(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.up_mode) + self.up3 = UpBlockScSE(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.up_mode) + self.up4 = UpBlockScSE(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.up_mode) + + +class UNet2D_ScSE(UNet2D): + """ + Combining 2D U-Net with SCSE module. + + * Reference: Abhijit Guha Roy, Nassir Navab, Christian Wachinger: + Recalibrating Fully Convolutional Networks With Spatial and Channel + "Squeeze and Excitation" Blocks. + `IEEE Trans. Med. Imaging 38(2): 540-549 (2019). `_ + + The parameters for the backbone should be given in the `params` dictionary. + See :mod:`pymic.net.net2d.unet2d.unet2d` for details. + """ + def __init__(self, params): + super(UNet2D_ScSE, self).__init__(params) + self.encoder = Encoder(params) + self.decoder = Decoder(params) diff --git a/pymic/net/net3d/trans3d/unetr_pp.py b/pymic/net/net3d/trans3d/unetr_pp.py index 3ef6736..a4ab7e6 100644 --- a/pymic/net/net3d/trans3d/unetr_pp.py +++ b/pymic/net/net3d/trans3d/unetr_pp.py @@ -155,7 +155,6 @@ def __init__( def forward(self, x): B, C, H, W, D = x.shape - x = x.reshape(B, C, H * W * D).permute(0, 2, 1) if self.pos_embed is not None: @@ -170,12 +169,13 @@ def forward(self, x): class UnetrPPEncoder(nn.Module): def __init__(self, input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4],dims=[32, 64, 128, 256], - proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, in_channels=1, dropout=0.0, transformer_dropout_rate=0.15 ,**kwargs): + proj_size =[64,64,64,32], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=3, + in_channels=1, dropout=0.0, transformer_dropout_rate=0.15, kernel_size=(2,4,4), **kwargs): super().__init__() self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers stem_layer = nn.Sequential( - get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=(2, 4, 4), stride=(2, 4, 4), + get_conv_layer(spatial_dims, in_channels, dims[0], kernel_size=kernel_size, stride=kernel_size, dropout=dropout, conv_only=True, ), get_norm_layer(name=("group", {"num_groups": in_channels}), channels=dims[0]), ) @@ -209,7 +209,6 @@ def _init_weights(self, m): def forward_features(self, x): hidden_states = [] - x = self.downsample_layers[0](x) x = self.stages[0](x) @@ -330,6 +329,7 @@ def __init__(self, params): in_channels = params['in_chns'] out_channels = params['class_num'] img_size = params['img_size'] + self.res_mode= params.get("resolution_mode", 1) feature_size = params.get('feature_size', 16) hidden_size 
= params.get('hidden_size', 256) num_heads = params.get('num_heads', 4) @@ -350,15 +350,20 @@ def __init__(self, params): if pos_embed not in ["conv", "perceptron"]: raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") - self.patch_size = (2, 4, 4) + kernel_ds = [4, 2, 1] + kernel_d = kernel_ds[self.res_mode] + self.patch_size = (kernel_d, 4, 4) + self.feat_size = ( img_size[0] // self.patch_size[0] // 8, # 8 is the downsampling happened through the four encoders stages img_size[1] // self.patch_size[1] // 8, # 8 is the downsampling happened through the four encoders stages img_size[2] // self.patch_size[2] // 8, # 8 is the downsampling happened through the four encoders stages ) + self.hidden_size = hidden_size - self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads) + self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads, + in_channels=in_channels, kernel_size=self.patch_size) self.encoder1 = UnetResBlock( spatial_dims=3, @@ -395,20 +400,21 @@ def __init__(self, params): norm_name=norm_name, out_size=32 * 32 * 32, ) + self.decoder2 = UnetrUpBlock( spatial_dims=3, in_channels=feature_size * 2, out_channels=feature_size, kernel_size=3, - upsample_kernel_size=(2, 4, 4), + upsample_kernel_size= self.patch_size, norm_name=norm_name, - out_size=64 * 128 * 128, + out_size= kernel_d*32 * 128 * 128, conv_decoder=True, ) self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) - if self.do_ds: - self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) - self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) + # if self.do_ds: + self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) + self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) def proj_feat(self, x, hidden_size, feat_size): x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) @@ -442,19 +448,22 @@ def forward(self, x_in): if __name__ == "__main__": - params = {'in_chns': 1, - 'class_num': 2, - 'img_size': [64, 128, 128] - } - net = UNETR_PP(params) - net.double() - - x = np.random.rand(2, 1, 64, 128, 128) - xt = torch.from_numpy(x) - xt = torch.tensor(xt) - - y = net(xt) - print(len(y)) - for yi in y: - yi = yi.detach().numpy() - print(yi.shape) \ No newline at end of file + depths = [128, 64, 32] + for i in range(3): + params = {'in_chns': 4, + 'class_num': 2, + 'img_size': [depths[i], 128, 128], + 'resolution_mode': i + } + net = UNETR_PP(params) + net.double() + + x = np.random.rand(2, 4, depths[i], 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = net(xt) + print(len(y)) + for yi in y: + yi = yi.detach().numpy() + print(yi.shape) \ No newline at end of file diff --git a/pymic/net/net_dict_cls.py b/pymic/net/net_dict_cls.py index 7996e59..3a7808b 100644 --- a/pymic/net/net_dict_cls.py +++ b/pymic/net/net_dict_cls.py @@ -3,7 +3,7 @@ Built-in networks for classification. 
* resnet18 :mod:`pymic.net.cls.torch_pretrained_net.ResNet18` -* vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` +* vgg16 :mod:`pymic.net.cls.torch_pretrained_net.VGG16` * mobilenetv2 :mod:`pymic.net.cls.torch_pretrained_net.MobileNetV2` """ diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index ffaa023..e381421 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -7,6 +7,7 @@ * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT` * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE` * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D` +* MCNet2D :mod:`pymic.net.net2d.unet2d_mcnet.MCNet2D` * NestedUNet2D :mod:`pymic.net.net2d.unet2d_nest.NestedUNet2D` * COPLENet :mod:`pymic.net.net2d.cople_net.COPLENet` * UNet2D5 :mod:`pymic.net.net3d.unet2d5.UNet2D5` @@ -16,13 +17,15 @@ from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch +from pymic.net.net2d.unet2d_canet import CANet from pymic.net.net2d.unet2d_cct import UNet2D_CCT +from pymic.net.net2d.unet2d_mcnet import MCNet2D from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D from pymic.net.net2d.unet2d_scse import UNet2D_ScSE -# from pymic.net.net2d.trans2d.transunet import TransUNet -# from pymic.net.net2d.trans2d.swinunet import SwinUNet +from pymic.net.net2d.trans2d.transunet import TransUNet +from pymic.net.net2d.trans2d.swinunet import SwinUNet from pymic.net.net3d.unet2d5 import UNet2D5 from pymic.net.net3d.unet3d import UNet3D from pymic.net.net3d.unet3d_scse import UNet3D_ScSE @@ -39,17 +42,20 @@ # from pymic.net.net3d.trans3d.HiFormer_v3 import HiFormer_v3 # from pymic.net.net3d.trans3d.HiFormer_v4 import HiFormer_v4 # from pymic.net.net3d.trans3d.HiFormer_v5 import HiFormer_v5 +# from pymic.net.net3d.trans3d.SwitchNet import SwitchNet SegNetDict = { 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_CCT': UNet2D_CCT, + 'MCNet2D': MCNet2D, + 'CANet': CANet, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, 'UNet2D_ScSE': UNet2D_ScSE, - # 'TransUNet': TransUNet, - # 'SwinUNet': SwinUNet, + 'TransUNet': TransUNet, + 'SwinUNet': SwinUNet, 'UNet2D5': UNet2D5, 'UNet3D': UNet3D, 'UNet3D_ScSE': UNet3D_ScSE, @@ -66,4 +72,5 @@ # 'HiFormer_v3': HiFormer_v3, # 'HiFormer_v4': HiFormer_v4, # 'HiFormer_v5': HiFormer_v5 + # 'SwitchNet': SwitchNet } diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 7a49a2b..f9575ab 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -154,6 +154,15 @@ def get_checkpoint_name(self): ckpt_name = self.config['testing']['ckpt_name'] return ckpt_name + @abstractmethod + def get_stage_transform_from_config(self, stage): + """ + Get the transform list required by dataset for training, validation or inference stage. + + :param stage: (str) `train`, `valid` or `test`. 
+ """ + raise(ValueError("not implemented")) + @abstractmethod def get_stage_dataset_from_config(self, stage): """ @@ -261,13 +270,13 @@ def worker_init_fn(worker_id): bn_train = self.config['dataset']['train_batch_size'] bn_valid = self.config['dataset'].get('valid_batch_size', 1) - num_worker = self.config['dataset'].get('num_worker', 16) + num_worker = self.config['dataset'].get('num_worker', 8) g_train, g_valid = torch.Generator(), torch.Generator() g_train.manual_seed(self.random_seed) g_valid.manual_seed(self.random_seed) self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size = bn_train, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init, generator = g_train) + worker_init_fn=worker_init, generator = g_train, drop_last = True) self.valid_loader = torch.utils.data.DataLoader(self.valid_set, batch_size = bn_valid, shuffle=False, num_workers= num_worker, worker_init_fn=worker_init, generator = g_valid) diff --git a/pymic/net_run/agent_preprocess.py b/pymic/net_run/agent_preprocess.py new file mode 100644 index 0000000..c681de9 --- /dev/null +++ b/pymic/net_run/agent_preprocess.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import os +import sys +import torch +import torchvision.transforms as transforms +from pymic.util.parse_config import * +from pymic.io.image_read_write import save_nd_array_as_image +from pymic.io.nifty_dataset import NiftyDataset +from pymic.transform.trans_dict import TransformDict + + + +class PreprocessAgent(object): + def __init__(self, config): + super(PreprocessAgent, self).__init__() + self.config = config + self.transform_dict = TransformDict + self.task_type = config['dataset']['task_type'] + self.dataloader = None + self.dataloader_unlab= None + + def get_dataset_from_config(self): + root_dir = self.config['dataset']['root_dir'] + modal_num = self.config['dataset'].get('modal_num', 1) + transform_names = self.config['dataset']["transform"] + + self.transform_list = [] + if(transform_names is None or len(transform_names) == 0): + data_transform = None + else: + transform_param = self.config['dataset'] + transform_param['task'] = self.task_type + for name in transform_names: + if(name not in self.transform_dict): + raise(ValueError("Undefined transform {0:}".format(name))) + one_transform = self.transform_dict[name](transform_param) + self.transform_list.append(one_transform) + data_transform = transforms.Compose(self.transform_list) + + data_csv = self.config['dataset'].get('data_csv', None) + data_csv_unlab = self.config['dataset'].get('data_csv_unlab', None) + if(data_csv is not None): + dataset = NiftyDataset(root_dir = root_dir, + csv_file = data_csv, + modal_num = modal_num, + with_label= True, + transform = data_transform, + task = self.task_type) + self.dataloader = torch.utils.data.DataLoader(dataset, + batch_size = 1, shuffle=False, num_workers= 8, + worker_init_fn=None, generator = torch.Generator()) + if(data_csv_unlab is not None): + dataset_unlab = NiftyDataset(root_dir = root_dir, + csv_file = data_csv_unlab, + modal_num = modal_num, + with_label= False, + transform = data_transform, + task = self.task_type) + self.dataloader_unlab = torch.utils.data.DataLoader(dataset_unlab, + batch_size = 1, shuffle=False, num_workers= 8, + worker_init_fn=None, generator = torch.Generator()) + + def run(self): + """ + Do preprocessing for labeled and unlabeled data. 
+ """ + self.get_dataset_from_config() + out_dir = self.config['dataset']['output_dir'] + for dataloader in [self.dataloader, self.dataloader_unlab]: + for item in dataloader: + img = item['image'][0] # the batch size is 1 + # save differnt modaliteis + img_names = item['names'] + spacing = [x.numpy()[0] for x in item['spacing']] + for i in range(img.shape[0]): + image_name = out_dir + "/" + img_names[i][0] + print(image_name) + save_nd_array_as_image(img[i], image_name, reference_name = None, spacing=spacing) + if('label' in item): + lab = item['label'][0] + label_name = out_dir + "/" + img_names[-1][0] + print(label_name) + save_nd_array_as_image(lab[0], label_name, reference_name = None, spacing=spacing) + +def main(): + """ + The main function for data preprocessing. + """ + if(len(sys.argv) < 2): + print('Number of arguments should be 2. e.g.') + print(' pymic_preprocess config.cfg') + exit() + cfg_file = str(sys.argv[1]) + if(not os.path.isfile(cfg_file)): + raise ValueError("The config file does not exist: " + cfg_file) + config = parse_config(cfg_file) + config = synchronize_config(config) + agent = PreprocessAgent(config) + agent.run() + +if __name__ == "__main__": + main() + diff --git a/pymic/net_run/agent_rec.py b/pymic/net_run/agent_rec.py index 1e58bc6..cd311ad 100644 --- a/pymic/net_run/agent_rec.py +++ b/pymic/net_run/agent_rec.py @@ -29,14 +29,6 @@ class ReconstructionAgent(SegmentationAgent): """ def __init__(self, config, stage = 'train'): super(ReconstructionAgent, self).__init__(config, stage) - output_act_name = config['network'].get('output_activation', 'sigmoid') - if(output_act_name == "sigmoid"): - self.out_act = nn.Sigmoid() - elif(output_act_name == "tanh"): - self.out_act = nn.Tanh() - else: - raise ValueError("For reconstruction task, only sigmoid and tanh are " + \ - "supported for output_activation.") def create_loss_calculator(self): if(self.loss_dict is None): @@ -48,7 +40,6 @@ def create_loss_calculator(self): raise ValueError("Undefined loss function {0:}".format(loss_name)) else: loss_param = self.config['training'] - loss_param['loss_softmax'] = False base_loss = self.loss_dict[loss_name](self.config['training']) if(self.config['training'].get('deep_supervise', False)): raise ValueError("Deep supervised loss not implemented for reconstruction tasks") @@ -80,8 +71,13 @@ def training(self): # print(inputs.shape) # for i in range(inputs.shape[0]): # image_i = inputs[i][0] + # label_i = label[i][0] # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # if(it > 10): + # break # return inputs, label = inputs.to(self.device), label.to(self.device) @@ -91,7 +87,18 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) - outputs = self.out_act(outputs) + + # for debug + # if it < 5: + # outputs = nn.Tanh()(outputs) + # for i in range(inputs.shape[0]): + # out_name = "temp/output_{0:}_{1:}.nii.gz".format(it, i) + # output = outputs[i][0] + # output = output.cpu().detach().numpy() + # save_nd_array_as_image(output, out_name, reference_name = None) + # else: + # break + loss = self.get_loss_value(data, outputs, label) loss.backward() self.optimizer.step() @@ -123,7 +130,6 @@ def validation(self): label = self.convert_tensor_type(data['label']) inputs, label = inputs.to(self.device), label.to(self.device) outputs = 
self.inferer.run(self.net, inputs) - outputs = self.out_act(outputs) # The tensors are on CPU when calculating loss for validation data loss = self.get_loss_value(data, outputs, label) valid_loss_list.append(loss.item()) @@ -187,7 +193,7 @@ def train_valid(self): self.min_val_loss = 10000.0 self.max_val_it = 0 self.best_model_wts = None - self.checkpoint = None + checkpoint = None # initialize the network with pre-trained weights ckpt_init_name = self.config['training'].get('ckpt_init_name', None) ckpt_init_mode = self.config['training'].get('ckpt_init_mode', 0) @@ -206,7 +212,7 @@ def train_valid(self): else: self.net.load_state_dict(pretrained_dict, strict = False) if(ckpt_init_mode > 0): # Load other information - self.min_val_loss = self.checkpoint.get('valid_loss', 10000) + self.min_val_loss = checkpoint.get('valid_loss', 10000) iter_start = checkpoint['iteration'] self.max_val_it = iter_start self.best_model_wts = checkpoint['model_state_dict'] @@ -293,19 +299,20 @@ def save_outputs(self, data): names, pred = data['names'], data['predict'] if(isinstance(pred, (list, tuple))): pred = pred[0] - if(isinstance(self.out_act, nn.Sigmoid)): - pred = scipy.special.expit(pred) - else: - pred = np.tanh(pred) + pred = np.tanh(pred) + # pred = scipy.special.expit(pred) # save the output predictions - root_dir = self.config['dataset']['root_dir'] + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + for i in range(len(names)): - save_name = names[i].split('/')[-1] if ignore_dir else \ - names[i].replace('/', '_') + save_name = names[i][0].split('/')[-1] if ignore_dir else \ + names[i][0].replace('/', '_') if((filename_replace_source is not None) and (filename_replace_target is not None)): save_name = save_name.replace(filename_replace_source, filename_replace_target) print(save_name) save_name = "{0:}/{1:}".format(output_dir, save_name) - save_nd_array_as_image(pred[i][i], save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(pred[i][i], save_name, test_dir + '/' + names[i][0]) \ No newline at end of file diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 70f01b3..2d6d489 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -18,6 +18,7 @@ from pymic.io.image_read_write import save_nd_array_as_image from pymic.io.nifty_dataset import NiftyDataset from pymic.net.net_dict_seg import SegNetDict +from pymic.net.multi_net import MultiNet from pymic.net_run.agent_abstract import NetRunAgent from pymic.net_run.infer_func import Inferer from pymic.loss.loss_dict_seg import SegLossDict @@ -38,36 +39,42 @@ def __init__(self, config, stage = 'train'): self.net_dict = SegNetDict self.postprocess_dict = PostProcessDict self.postprocessor = None - - def get_stage_dataset_from_config(self, stage): - assert(stage in ['train', 'valid', 'test']) - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset'].get('modal_num', 1) + def get_transform_names_and_parameters(self, stage): + """ + Get a list of transform objects for creating a dataset + """ + assert(stage in ['train', 'valid', 'test']) transform_key = stage + '_transform' - if(stage == "valid" and transform_key not in self.config['dataset']): - transform_key = "train_transform" - transform_names = self.config['dataset'][transform_key] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - 
transform_param['task'] = self.task_type - for name in transform_names: + trans_names = self.config['dataset'][transform_key] + trans_params = self.config['dataset'] + trans_params['task'] = self.task_type + return trans_names, trans_params + + def get_stage_dataset_from_config(self, stage): + trans_names, trans_params = self.get_transform_names_and_parameters(stage) + transform_list = [] + if(trans_names is not None and len(trans_names) > 0): + for name in trans_names: if(name not in self.transform_dict): raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) + one_transform = self.transform_dict[name](trans_params) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) - csv_file = self.config['dataset'].get(stage + '_csv', None) + csv_file = self.config['dataset'].get(stage + '_csv', None) if(stage == 'test'): with_label = False + self.test_transforms = transform_list else: with_label = self.config['dataset'].get(stage + '_label', True) - dataset = NiftyDataset(root_dir = root_dir, + modal_num = self.config['dataset'].get('modal_num', 1) + stage_dir = self.config['dataset'].get('train_dir', None) + if(stage == 'valid' and "valid_dir" in self.config['dataset']): + stage_dir = self.config['dataset']['valid_dir'] + if(stage == 'test' and "test_dir" in self.config['dataset']): + stage_dir = self.config['dataset']['test_dir'] + dataset = NiftyDataset(root_dir = stage_dir, csv_file = csv_file, modal_num = modal_num, with_label= with_label, @@ -78,13 +85,18 @@ def get_stage_dataset_from_config(self, stage): def create_network(self): if(self.net is None): net_name = self.config['network']['net_type'] - if(net_name not in self.net_dict): - raise ValueError("Undefined network {0:}".format(net_name)) - self.net = self.net_dict[net_name](self.config['network']) + if(isinstance(net_name, (tuple, list))): + self.net = MultiNet(self.net_dict, self.config['network']) + else: + if(net_name not in self.net_dict): + raise ValueError("Undefined network {0:}".format(net_name)) + self.net = self.net_dict[net_name](self.config['network']) if(self.tensor_type == 'float'): self.net.float() else: self.net.double() + if(hasattr(self.net, "set_stage")): + self.net.set_stage(self.stage) param_number = sum(p.numel() for p in self.net.parameters() if p.requires_grad) logging.info('parameter number {0:}'.format(param_number)) @@ -164,10 +176,13 @@ def training(self): if(mixup_prob > 0 and random() < mixup_prob): inputs, labels_prob = mixup(inputs, labels_prob) - # # for debug + # for debug + # if(it > 10): + # break # for i in range(inputs.shape[0]): # image_i = inputs[i][0] - # label_i = labels_prob[i][1] + # # label_i = labels_prob[i][1] + # label_i = np.argmax(labels_prob[i], axis = 0) # # pixw_i = pix_w[i][0] # print(image_i.shape, label_i.shape) # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) @@ -176,7 +191,8 @@ def training(self): # save_nd_array_as_image(image_i, image_name, reference_name = None) # save_nd_array_as_image(label_i, label_name, reference_name = None) # # save_nd_array_as_image(pixw_i, weight_name, reference_name = None) - # # continue + # continue + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) @@ -226,6 +242,9 @@ def validation(self): self.net.eval() for data in validIter: inputs = self.convert_tensor_type(data['image']) + if('label_prob' not in data): + 
raise ValueError("label_prob is not found in validation data, make sure" + + "that LabelToProbability is used in valid_transform.") labels_prob = self.convert_tensor_type(data['label_prob']) inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) batch_n = inputs.shape[0] @@ -271,6 +290,27 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + def load_pretrained_weights(self, network, pretrained_dict, device_ids): + if(len(device_ids) > 1): + if(hasattr(network.module, "get_parameters_to_load")): + model_dict = network.module.get_parameters_to_load() + else: + model_dict = network.module.state_dict() + else: + if(hasattr(network, "get_parameters_to_load")): + model_dict = network.get_parameters_to_load() + else: + model_dict = network.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ + k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} + logging.info("Initializing the following parameters with pre-trained model") + for k in pretrained_dict: + logging.info(k) + if (len(device_ids) > 1): + network.module.load_state_dict(pretrained_dict, strict = False) + else: + network.load_state_dict(pretrained_dict, strict = False) + def train_valid(self): device_ids = self.config['training']['gpus'] if(len(device_ids) > 1): @@ -310,16 +350,7 @@ def train_valid(self): if(ckpt_init_name is not None): checkpoint = torch.load(ckpt_dir + "/" + ckpt_init_name, map_location = self.device) pretrained_dict = checkpoint['model_state_dict'] - model_dict = self.net.module.state_dict() if (len(device_ids) > 1) else self.net.state_dict() - pretrained_dict = {k: v for k, v in pretrained_dict.items() if \ - k in model_dict and tensor_shape_match(pretrained_dict[k], model_dict[k])} - logging.info("Initializing the following parameters with pre-trained model") - for k in pretrained_dict: - logging.info(k) - if (len(device_ids) > 1): - self.net.module.load_state_dict(pretrained_dict, strict = False) - else: - self.net.load_state_dict(pretrained_dict, strict = False) + self.load_pretrained_weights(self.net, pretrained_dict, device_ids) if(ckpt_init_mode > 0): # Load other information self.max_val_dice = checkpoint.get('valid_pred', 0) @@ -452,7 +483,7 @@ def test_time_dropout(m): pred = pred.cpu().numpy() data['predict'] = pred # inverse transform - for transform in self.transform_list[::-1]: + for transform in self.test_transforms[::-1]: if (transform.inverse): data = transform.inverse_transform_for_prediction(data) @@ -506,7 +537,7 @@ def infer_with_multiple_checkpoints(self): pred = np.mean(predict_list, axis=0) data['predict'] = pred # inverse transform - for transform in self.transform_list[::-1]: + for transform in self.test_transforms[::-1]: if (transform.inverse): data = transform.inverse_transform_for_prediction(data) @@ -545,15 +576,18 @@ def save_outputs(self, data): for i in range(len(names)): output[i] = self.postprocessor(output[i]) # save the output and (optionally) probability predictions - root_dir = self.config['dataset']['root_dir'] + test_dir = self.config['dataset'].get('test_dir', None) + if(test_dir is None): + test_dir = self.config['dataset']['train_dir'] + for i in range(len(names)): - save_name = names[i].split('/')[-1] if ignore_dir else \ - names[i].replace('/', '_') + save_name = names[i][0].split('/')[-1] if ignore_dir else \ + names[i][0].replace('/', '_') 
if((filename_replace_source is not None) and (filename_replace_target is not None)): save_name = save_name.replace(filename_replace_source, filename_replace_target) print(save_name) save_name = "{0:}/{1:}".format(output_dir, save_name) - save_nd_array_as_image(output[i], save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(output[i], save_name, test_dir + '/' + names[i][0]) save_name_split = save_name.split('.') if(not save_prob): @@ -571,4 +605,4 @@ def save_outputs(self, data): prob_save_name = "{0:}_prob_{1:}.{2:}".format(save_prefix, c, save_format) if(len(temp_prob.shape) == 2): temp_prob = np.asarray(temp_prob * 255, np.uint8) - save_nd_array_as_image(temp_prob, prob_save_name, root_dir + '/' + names[i]) + save_nd_array_as_image(temp_prob, prob_save_name, test_dir + '/' + names[i][0]) diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index ad8fda0..771b448 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -13,8 +13,9 @@ def get_optimizer(name, net_params, optim_params): # see https://www.codeleading.com/article/44815584159/ param_group = [{'params': net_params, 'initial_lr': lr}] if(keyword_match(name, "SGD")): + nesterov = optim_params.get('nesterov', True) return optim.SGD(param_group, lr, - momentum = momentum, weight_decay = weight_decay) + momentum = momentum, weight_decay = weight_decay, nesterov = nesterov) elif(keyword_match(name, "Adam")): return optim.Adam(param_group, lr, weight_decay = weight_decay) elif(keyword_match(name, "SparseAdam")): diff --git a/pymic/net_run/preprocess.py b/pymic/net_run/preprocess.py new file mode 100644 index 0000000..f2bbe0f --- /dev/null +++ b/pymic/net_run/preprocess.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import os +import sys +import shutil +from datetime import datetime +from pymic import TaskType +from pymic.util.parse_config import * +from pymic.net_run.agent_cls import ClassificationAgent +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.semi_sup import SSLMethodDict +from pymic.net_run.weak_sup import WSLMethodDict +from pymic.net_run.self_sup import SelfSupMethodDict +from pymic.net_run.noisy_label import NLLMethodDict +# from pymic.net_run.self_sup import SelfSLSegAgent + +def get_seg_rec_agent(config, sup_type): + assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) + if(sup_type == 'fully_sup'): + logging.info("\n********** Fully Supervised Learning **********\n") + agent = SegmentationAgent(config, 'train') + elif(sup_type == 'semi_sup'): + logging.info("\n********** Semi Supervised Learning **********\n") + method = config['semi_supervised_learning']['method_name'] + agent = SSLMethodDict[method](config, 'train') + elif(sup_type == 'weak_sup'): + logging.info("\n********** Weakly Supervised Learning **********\n") + method = config['weakly_supervised_learning']['method_name'] + agent = WSLMethodDict[method](config, 'train') + elif(sup_type == 'noisy_label'): + logging.info("\n********** Noisy Label Learning **********\n") + method = config['noisy_label_learning']['method_name'] + agent = NLLMethodDict[method](config, 'train') + elif(sup_type == 'self_sup'): + logging.info("\n********** Self Supervised Learning **********\n") + method = config['self_supervised_learning']['method_name'] + agent = SelfSupMethodDict[method](config, 'train') + else: + raise ValueError("undefined supervision type: {0:}".format(sup_type)) + return agent + 
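# --- Minimal dispatch sketch (illustrative only, not part of the PyMIC source). ---
# It shows how the supervision type and method name from a parsed config select the
# agent class that get_seg_rec_agent() above would instantiate. Only the keys used
# for dispatch are shown; a complete config also needs the dataset, network,
# training and inference sections before an agent can actually be constructed.
from pymic.net_run.self_sup import SelfSupMethodDict

toy_config = {
    'dataset': {'task_type': 'seg', 'supervise_type': 'self_sup'},
    'self_supervised_learning': {'method_name': 'VolumeFusion'},
}
sup_type    = toy_config['dataset'].get('supervise_type', 'fully_sup')
method_name = toy_config['self_supervised_learning']['method_name']
agent_class = SelfSupMethodDict[method_name]    # -> SelfSupVolumeFusion
print(sup_type, agent_class.__name__)
# With a complete config: agent = get_seg_rec_agent(config, sup_type); agent.run()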
+def main():
+    """
+    The main function for running a network for training.
+    """
+    if(len(sys.argv) < 2):
+        print('Number of arguments should be 2. e.g.')
+        print('   pymic_train config.cfg')
+        exit()
+    cfg_file = str(sys.argv[1])
+    if(not os.path.isfile(cfg_file)):
+        raise ValueError("The config file does not exist: " + cfg_file)
+    config   = parse_config(cfg_file)
+    config   = synchronize_config(config)
+    log_dir  = config['training']['ckpt_save_dir']
+    if(not os.path.exists(log_dir)):
+        os.makedirs(log_dir, exist_ok=True)
+    dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1]
+    shutil.copy(cfg_file, log_dir + "/" + dst_cfg)
+    datetime_str = str(datetime.now())[:-7].replace(":", "_")
+    if sys.version.startswith("3.9"):
+        logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str),
+            level=logging.INFO, format='%(message)s', force=True) # for python 3.9
+    else:
+        logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(datetime_str),
+            level=logging.INFO, format='%(message)s') # for python 3.6
+    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
+    logging_config(config)
+    task = config['dataset']['task_type']
+    if(task == TaskType.CLASSIFICATION_ONE_HOT or task == TaskType.CLASSIFICATION_COEXIST):
+        agent = ClassificationAgent(config, 'train')
+    else:
+        sup_type = config['dataset'].get('supervise_type', 'fully_sup')
+        agent = get_seg_rec_agent(config, sup_type)
+
+    agent.run()
+
+if __name__ == "__main__":
+    main()
+
+
diff --git a/pymic/net_run/self_sup/__init__.py b/pymic/net_run/self_sup/__init__.py
index 73308e6..d73e42a 100644
--- a/pymic/net_run/self_sup/__init__.py
+++ b/pymic/net_run/self_sup/__init__.py
@@ -1,3 +1,10 @@
 from __future__ import absolute_import
-from pymic.net_run.self_sup.self_sl_agent import SelfSLSegAgent
-from pymic.net_run.self_sup.self_patch_mix_agent import SelfSLPatchMixAgent
\ No newline at end of file
+from pymic.net_run.self_sup.self_genesis import SelfSupModelGenesis
+from pymic.net_run.self_sup.self_patch_swapping import SelfSupPatchSwapping
+from pymic.net_run.self_sup.self_volume_fusion import SelfSupVolumeFusion
+
+SelfSupMethodDict = {
+    'ModelGenesis': SelfSupModelGenesis,
+    'PatchSwapping': SelfSupPatchSwapping,
+    'VolumeFusion': SelfSupVolumeFusion
+    }
\ No newline at end of file
diff --git a/pymic/net_run/self_sup/self_genesis.py b/pymic/net_run/self_sup/self_genesis.py
new file mode 100644
index 0000000..85ee194
--- /dev/null
+++ b/pymic/net_run/self_sup/self_genesis.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, division
+import copy
+import logging
+import time
+from pymic.net_run.agent_rec import ReconstructionAgent
+
+class SelfSupModelGenesis(ReconstructionAgent):
+    """
+    Image restoration-based self-supervised learning following Model Genesis.
+
+    Reference: Zongwei Zhou et al., Models Genesis,
+    Medical Image Analysis, 2021.
+
+    Restoration transforms such as LocalShuffling, NonLinearTransform and
+    InOutPainting need to be used in the configuration.
+
+    .. note::
+
+        In the configuration dictionary, in addition to the four sections (`dataset`,
+        `network`, `training` and `inference`) used in fully supervised learning, an
+        extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details.
+ + In the configuration file, it should look like this: + ``` + [dataset] + task_type = rec + supervise_type = self_sup + train_transform = [..., ..., PatchSwaping] + valid_transform = [..., ..., PatchSwaping] + + [self_supervised_learning] + method_name = ModelGenesis + + """ + def __init__(self, config, stage = 'train'): + super(SelfSupModelGenesis, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupModelGenesis, self).get_transform_names_and_parameters(stage) + # if(stage == 'train'): + # print('training transforms:', trans_names) + # if("LocalShuffling" not in trans_names): + # raise ValueError("LocalShuffling is required for model genesis, \ + # but it is not given in training transform") + # if("NonLinearTransform" not in trans_names): + # raise ValueError("NonLinearTransform is required for model genesis, \ + # but it is not given in training transform") + # if("InOutPainting" not in trans_names): + # raise ValueError("InOutPainting is required for model genesis, \ + # but it is not given in training transform") + return trans_names, trans_params diff --git a/pymic/net_run/self_sup/self_patch_swapping.py b/pymic/net_run/self_sup/self_patch_swapping.py new file mode 100644 index 0000000..1692fa7 --- /dev/null +++ b/pymic/net_run/self_sup/self_patch_swapping.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import copy +import logging +import time +from pymic.net_run.agent_rec import ReconstructionAgent + +class SelfSupPatchSwapping(ReconstructionAgent): + """ + Patch swapping-based self-supervised learning. + + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. + + A PatchSwaping transform need to be used in the cnfiguration. + + .. note:: + + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. + + In the configuration file, it should look like this: + ``` + [dataset] + task_type = rec + supervise_type = self_sup + train_transform = [..., ..., PatchSwaping] + valid_transform = [..., ..., PatchSwaping] + + [self_supervised_learning] + method_name = PatchSwapping + + """ + def __init__(self, config, stage = 'train'): + super(SelfSupPatchSwapping, self).__init__(config, stage) + + def get_transform_names_and_parameters(self, stage): + trans_names, trans_params = super(SelfSupPatchSwapping, self).get_transform_names_and_parameters(stage) + if(stage == 'train'): + print('training transforms:', trans_names) + assert("PatchSwaping" in trans_names) + return trans_names, trans_params + diff --git a/pymic/net_run/self_sup/self_sl_agent.py b/pymic/net_run/self_sup/self_sl_agent.py index c352adf..45bee26 100644 --- a/pymic/net_run/self_sup/self_sl_agent.py +++ b/pymic/net_run/self_sup/self_sl_agent.py @@ -6,6 +6,7 @@ from pymic.net_run.agent_rec import ReconstructionAgent + class SelfSLSegAgent(ReconstructionAgent): """ Abstract class for self-supervised segmentation. @@ -17,7 +18,7 @@ class SelfSLSegAgent(ReconstructionAgent): In the configuration dictionary, in addition to the four sections (`dataset`, `network`, `training` and `inference`) used in fully supervised learning, an - extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. 
+ extra section `self_supervised_learning` is needed. See :doc:`usage.selfsl` for details. """ def __init__(self, config, stage = 'train'): super(SelfSLSegAgent, self).__init__(config, stage) diff --git a/pymic/net_run/self_sup/self_patch_mix_agent.py b/pymic/net_run/self_sup/self_volume_fusion.py similarity index 71% rename from pymic/net_run/self_sup/self_patch_mix_agent.py rename to pymic/net_run/self_sup/self_volume_fusion.py index e30a131..91fe088 100644 --- a/pymic/net_run/self_sup/self_patch_mix_agent.py +++ b/pymic/net_run/self_sup/self_volume_fusion.py @@ -32,11 +32,13 @@ from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label from pymic.util.parse_config import * +from pymic.util.general import get_one_hot_seg from pymic.io.image_read_write import save_nd_array_as_image -from pymic.net_run.self_sup.util import patch_mix +from pymic.net_run.self_sup.util import volume_fusion from pymic.net_run.agent_seg import SegmentationAgent -class SelfSLPatchMixAgent(SegmentationAgent): + +class SelfSupVolumeFusion(SegmentationAgent): """ Abstract class for self-supervised segmentation. @@ -50,16 +52,15 @@ class SelfSLPatchMixAgent(SegmentationAgent): extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. """ def __init__(self, config, stage = 'train'): - super(SelfSLPatchMixAgent, self).__init__(config, stage) + super(SelfSupVolumeFusion, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - fg_num = self.config['network']['class_num'] - 1 - patch_num = self.config['patch_mix']['patch_num_range'] - size_d = self.config['patch_mix']['patch_depth_range'] - size_h = self.config['patch_mix']['patch_height_range'] - size_w = self.config['patch_mix']['patch_width_range'] + cls_num = self.config['network']['class_num'] + block_range = self.config['self_supervised_learning']['VolumeFusion_block_range'.lower()] + size_min = self.config['self_supervised_learning']['VolumeFusion_size_min'.lower()] + size_max = self.config['self_supervised_learning']['VolumeFusion_size_max'.lower()] train_loss = 0 train_dice_list = [] @@ -72,16 +73,16 @@ def training(self): data = next(self.trainIter) # get the inputs inputs = self.convert_tensor_type(data['image']) - inputs, labels_prob = patch_mix(inputs, fg_num, patch_num, size_d, size_h, size_w) + inputs, labels = volume_fusion(inputs, cls_num - 1, block_range, size_min, size_max) + labels_prob = get_one_hot_seg(labels, cls_num) - # # for debug + # for debug # if(it==10): # break # for i in range(inputs.shape[0]): # image_i = inputs[i][0] # label_i = np.argmax(labels_prob[i], axis = 0) # # pixw_i = pix_w[i][0] - # print(image_i.shape, label_i.shape) # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) # label_name = "temp/label_{0:}_{1:}.nii.gz".format(it, i) # # weight_name= "temp/weight_{0:}_{1:}.nii.gz".format(it, i) @@ -116,29 +117,3 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'avg_fg_dice':train_avg_dice,\ 'class_dice': train_cls_dice} return train_scalers - -def main(): - cfg_file = str(sys.argv[1]) - if(not os.path.isfile(cfg_file)): - raise ValueError("The config file does not exist: " + cfg_file) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.makedirs(log_dir, exist_ok=True) - dst_cfg = cfg_file if "/" not in cfg_file else cfg_file.split("/")[-1] - 
shutil.copy(cfg_file, log_dir + "/" + dst_cfg) - if sys.version.startswith("3.9"): - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), - level=logging.INFO, format='%(message)s', force=True) # for python 3.9 - else: - logging.basicConfig(filename=log_dir+"/log_train_{0:}.txt".format(str(datetime.now())[:-7]), - level=logging.INFO, format='%(message)s') # for python 3.6 - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - agent = SelfSLPatchMixAgent(config) - agent.run() - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/pymic/net_run/self_sup/util.py b/pymic/net_run/self_sup/util.py index 9cffaa7..db27702 100644 --- a/pymic/net_run/self_sup/util.py +++ b/pymic/net_run/self_sup/util.py @@ -7,7 +7,7 @@ from scipy import ndimage from pymic.io.image_read_write import * from pymic.util.image_process import * -from pymic.util.general import get_one_hot_seg + def get_human_region_mask(img): """ @@ -19,6 +19,10 @@ def get_human_region_mask(img): mask = np.asarray(img > -600) se = np.ones([3,3,3]) mask = ndimage.binary_opening(mask, se, iterations = 2) + D, H, W = mask.shape + for h in range(H): + if(mask[:,h,:].sum() < 2000): + mask[:,h, :] = np.zeros((D, W)) mask = get_largest_k_components(mask, 1) mask_close = ndimage.binary_closing(mask, se, iterations = 2) @@ -47,20 +51,39 @@ def get_human_region_mask(img): fg = np.asarray(fg, np.uint8) return fg -def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None): +def get_human_region_mask_fast(img, itk_spacing): + # downsample + D, H, W = img.shape + # scale_down = [1, 1, 1] + if(itk_spacing[2] <= 1): + scale_down = [1/2, 1/2, 1/2] + else: + scale_down = [1, 1/2, 1/2] + img_sub = ndimage.interpolation.zoom(img, scale_down, order = 0) + mask = get_human_region_mask(img_sub) + D1, H1, W1 = mask.shape + scale_up = [D/D1, H/H1, W/W1] + mask = ndimage.interpolation.zoom(mask, scale_up, order = 0) + return mask + +def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None, z_axis_density = 0.5): """ Crop a CT scan based on the bounding box of the human region. 
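    :param input_img: (str) The file name of the input CT volume.
    :param output_img: (str) The file name of the cropped output volume.
    :param input_lab: (optional, str) The file name of a label volume to be cropped with the same bounding box.
    :param output_lab: (optional, str) The file name of the cropped output label.
    :param z_axis_density: (float) A pixel is kept in the in-plane bounding box only if it is brighter than
        -600 HU in at least this fraction of the slices along the z axis. Default is 0.5.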
""" img_obj = sitk.ReadImage(input_img) - img = sitk.GetArrayFromImage(img_obj) - mask = np.asarray(img > -600) - se = np.ones([3,3,3]) - mask = ndimage.binary_opening(mask, se, iterations = 2) - mask = get_largest_k_components(mask, 1) - bbmin, bbmax = get_ND_bounding_box(mask, margin = [5, 10, 10]) + img = sitk.GetArrayFromImage(img_obj) + mask = np.asarray(img > -600) + mask2d = np.mean(mask, axis = 0) > z_axis_density + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + bbmin = [0] + bbmin + bbmax = [img.shape[0]] + bbmax img_sub = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) img_sub_obj = sitk.GetImageFromArray(img_sub) img_sub_obj.SetSpacing(img_obj.GetSpacing()) + img_sub_obj.SetDirection(img_obj.GetDirection()) sitk.WriteImage(img_sub_obj, output_img) if(input_lab is not None): lab_obj = sitk.ReadImage(input_lab) @@ -70,110 +93,80 @@ def crop_ct_scan(input_img, output_img, input_lab = None, output_lab = None): lab_sub_obj.SetSpacing(img_obj.GetSpacing()) sitk.WriteImage(lab_sub_obj, output_lab) +def get_human_body_mask_and_crop(input_dir, out_img_dir, out_mask_dir): + if(not os.path.exists(out_img_dir)): + os.mkdir(out_img_dir) + os.mkdir(out_mask_dir) + + img_names = [item for item in os.listdir(input_dir) if "nii.gz" in item] + img_names = sorted(img_names) + for img_name in img_names: + print(img_name) + input_name = input_dir + "/" + img_name + out_name = out_img_dir + "/" + img_name + mask_name = out_mask_dir + "/" + img_name + if(os.path.isfile(out_name)): + continue + img_obj = sitk.ReadImage(input_name) + img = sitk.GetArrayFromImage(img_obj) + spacing = img_obj.GetSpacing() + + # downsample + D, H, W = img.shape + spacing = img_obj.GetSpacing() + # scale_down = [1, 1, 1] + if(spacing[2] <= 1): + scale_down = [1/2, 1/2, 1/2] + else: + scale_down = [1, 1/2, 1/2] + img_sub = ndimage.interpolation.zoom(img, scale_down, order = 0) + mask = get_human_region_mask(img_sub) + D1, H1, W1 = mask.shape + scale_up = [D/D1, H/H1, W/W1] + mask = ndimage.interpolation.zoom(mask, scale_up, order = 0) + + bbmin, bbmax = get_ND_bounding_box(mask) + img_crop = crop_ND_volume_with_bounding_box(img, bbmin, bbmax) + mask_crop = crop_ND_volume_with_bounding_box(mask, bbmin, bbmax) + + out_img_obj = sitk.GetImageFromArray(img_crop) + out_img_obj.SetSpacing(spacing) + sitk.WriteImage(out_img_obj, out_name) + mask_obj = sitk.GetImageFromArray(mask_crop) + mask_obj.CopyInformation(out_img_obj) + sitk.WriteImage(mask_obj, mask_name) -def patch_mix(x, fg_num, patch_num, size_d, size_h, size_w): + +def volume_fusion(x, fg_num, block_range, size_min, size_max): """ - Copy a sub region of an impage and paste to another one to generate + Fuse a subregion of an impage with another one to generate images and labels for self-supervised segmentation. 
+ input x should be a batch of tensors """ + #n_min, n_max, N, C, D, H, W = list(x.shape) - fg_mask = torch.zeros_like(x) + fg_mask = torch.zeros_like(x).to(torch.int32) # generate mask for n in range(N): - p_num = random.randint(patch_num[0], patch_num[1]) + p_num = random.randint(block_range[0], block_range[1]) for i in range(p_num): - d = random.randint(size_d[0], size_d[1]) - h = random.randint(size_h[0], size_h[1]) - w = random.randint(size_w[0], size_w[1]) - d_c = random.randint(0, D) - h_c = random.randint(0, H) - w_c = random.randint(0, W) - d0, d1 = max(0, d_c - d), min(D, d_c + d) - h0, h1 = max(0, h_c - h), min(H, h_c + h) - w0, w1 = max(0, w_c - w), min(W, w_c + w) - temp_m = torch.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) + d = random.randint(size_min[0], size_max[0]) + h = random.randint(size_min[1], size_max[1]) + w = random.randint(size_min[2], size_max[2]) + dc = random.randint(0, D - 1) + hc = random.randint(0, H - 1) + wc = random.randint(0, W - 1) + d0 = dc - d // 2 + h0 = hc - h // 2 + w0 = wc - w // 2 + d1 = min(D, d0 + d) + h1 = min(H, h0 + h) + w1 = min(W, w0 + w) + d0, h0, w0 = max(0, d0), max(0, h0), max(0, w0) + temp_m = torch.ones([C, d1 - d0, h1 - h0, w1 - w0]) * random.randint(1, fg_num) fg_mask[n, :, d0:d1, h0:h1, w0:w1] = temp_m fg_w = fg_mask * 1.0 / fg_num x_roll = torch.roll(x, 1, 0) x_fuse = fg_w*x_roll + (1.0 - fg_w)*x - y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) - return x_fuse, y_prob - -def create_mixed_dataset(input_dir, output_dir, fg_num = 1, crop_num = 1, - mask_dir = None, data_format = "nii.gz"): - """ - Create dataset based on patch mix. - - :param input_dir: (str) The path of folder for input images - :param output_dir: (str) The path of folder for output images - :param fg_num: (int) The number of foreground classes - :param crop_num: (int) The number of patches to crop for each input image - :param mask: ND array to specify a mask, or 'default' or None. If default, - a mask for body region is automatically generated (just for CT). - :param data_format: (str) The format of images. 
- """ - img_names = os.listdir(input_dir) - img_names = [item for item in img_names if item.endswith(data_format)] - img_names = sorted(img_names) - out_img_dir = output_dir + "/image" - out_lab_dir = output_dir + "/label" - if(not os.path.exists(out_img_dir)): - os.mkdir(out_img_dir) - if(not os.path.exists(out_lab_dir)): - os.mkdir(out_lab_dir) - - img_num = len(img_names) - print("image number", img_num) - i_range = range(img_num) - j_range = list(i_range) - random.shuffle(j_range) - for i in i_range: - print(i, img_names[i]) - j = j_range[i] - if(i == j): - j = i + 1 if i < img_num - 1 else 0 - img_i = load_image_as_nd_array(input_dir + "/" + img_names[i])['data_array'] - img_j = load_image_as_nd_array(input_dir + "/" + img_names[j])['data_array'] - - chns = img_i.shape[0] - # random crop to patch size - if(mask_dir is None): - mask_i = get_human_region_mask(img_i) - mask_j = get_human_region_mask(img_j) - else: - mask_i = load_image_as_nd_array(mask_dir + "/" + img_names[i])['data_array'] - mask_j = load_image_as_nd_array(mask_dir + "/" + img_names[j])['data_array'] - for k in range(crop_num): - # if(mask is None): - # img_ik = random_crop_ND_volume(img_i, [chns, 96, 96, 96]) - # img_jk = random_crop_ND_volume(img_j, [chns, 96, 96, 96]) - # else: - img_ik = random_crop_ND_volume_with_mask(img_i, [chns, 96, 96, 96], mask_i) - img_jk = random_crop_ND_volume_with_mask(img_j, [chns, 96, 96, 96], mask_j) - C, D, H, W = img_ik.shape - # generate mask - fg_mask = np.zeros_like(img_ik, np.uint8) - patch_num = random.randint(4, 40) - for patch in range(patch_num): - d = random.randint(4, 20) # half of window size - h = random.randint(4, 40) - w = random.randint(4, 40) - d_c = random.randint(0, D) - h_c = random.randint(0, H) - w_c = random.randint(0, W) - d0, d1 = max(0, d_c - d), min(D, d_c + d) - h0, h1 = max(0, h_c - h), min(H, h_c + h) - w0, w1 = max(0, w_c - w), min(W, w_c + w) - temp_m = np.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, fg_num) - fg_mask[:, d0:d1, h0:h1, w0:w1] = temp_m - fg_w = fg_mask * 1.0 / fg_num - x_fuse = fg_w*img_jk + (1.0 - fg_w)*img_ik - - out_name = img_names[i] - if crop_num > 1: - out_name = out_name.replace(".nii.gz", "_{0:}.nii.gz".format(k)) - save_nd_array_as_image(x_fuse[0], out_img_dir + "/" + out_name, - reference_name = input_dir + "/" + img_names[i]) - save_nd_array_as_image(fg_mask[0], out_lab_dir + "/" + out_name, - reference_name = input_dir + "/" + img_names[i]) - + # y_prob = get_one_hot_seg(fg_mask.to(torch.int32), fg_num + 1) + return x_fuse, fg_mask diff --git a/pymic/net_run/semi_sup/__init__.py b/pymic/net_run/semi_sup/__init__.py index be753c2..d3095f6 100644 --- a/pymic/net_run/semi_sup/__init__.py +++ b/pymic/net_run/semi_sup/__init__.py @@ -2,6 +2,7 @@ from pymic.net_run.semi_sup.ssl_abstract import SSLSegAgent from pymic.net_run.semi_sup.ssl_em import SSLEntropyMinimization from pymic.net_run.semi_sup.ssl_mt import SSLMeanTeacher +from pymic.net_run.semi_sup.ssl_mcnet import SSLMCNet from pymic.net_run.semi_sup.ssl_uamt import SSLUncertaintyAwareMeanTeacher from pymic.net_run.semi_sup.ssl_cct import SSLCCT from pymic.net_run.semi_sup.ssl_cps import SSLCPS @@ -10,6 +11,7 @@ SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 'MeanTeacher': SSLMeanTeacher, + 'MCNet': SSLMCNet, 'UAMT': SSLUncertaintyAwareMeanTeacher, 'CCT': SSLCCT, 'CPS': SSLCPS, diff --git a/pymic/net_run/semi_sup/ssl_abstract.py b/pymic/net_run/semi_sup/ssl_abstract.py index b27edc9..0e05281 100644 --- a/pymic/net_run/semi_sup/ssl_abstract.py +++ 
b/pymic/net_run/semi_sup/ssl_abstract.py @@ -35,7 +35,7 @@ def get_unlabeled_dataset_from_config(self): """ Create a dataset for the unlabeled images based on configuration. """ - root_dir = self.config['dataset']['root_dir'] + train_dir = self.config['dataset']['train_dir'] modal_num = self.config['dataset'].get('modal_num', 1) transform_names = self.config['dataset']['train_transform_unlab'] @@ -53,7 +53,7 @@ def get_unlabeled_dataset_from_config(self): data_transform = transforms.Compose(self.transform_list) csv_file = self.config['dataset'].get('train_csv_unlab', None) - dataset = NiftyDataset(root_dir=root_dir, + dataset = NiftyDataset(root_dir = train_dir, csv_file = csv_file, modal_num = modal_num, with_label= False, @@ -76,7 +76,7 @@ def worker_init_fn(worker_id): num_worker = self.config['dataset'].get('num_worker', 16) self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab, batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init) + worker_init_fn=worker_init, drop_last = True) def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], diff --git a/pymic/net_run/semi_sup/ssl_cps.py b/pymic/net_run/semi_sup/ssl_cps.py index 4a3be9c..7acfe17 100644 --- a/pymic/net_run/semi_sup/ssl_cps.py +++ b/pymic/net_run/semi_sup/ssl_cps.py @@ -3,29 +3,14 @@ import logging import numpy as np import torch -import torch.nn as nn +from random import random from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice +from pymic.io.image_read_write import save_nd_array_as_image from pymic.net_run.semi_sup import SSLSegAgent -from pymic.net.net_dict_seg import SegNetDict from pymic.util.ramps import get_rampup_ratio - -class BiNet(nn.Module): - def __init__(self, params): - super(BiNet, self).__init__() - net_name = params['net_type'] - self.net1 = SegNetDict[net_name](params) - self.net2 = SegNetDict[net_name](params) - - def forward(self, x): - out1 = self.net1(x) - out2 = self.net2(x) - - if(self.training): - return out1, out2 - else: - return (out1 + out2) / 2 +from pymic.util.general import mixup, tensor_shape_match class SSLCPS(SSLSegAgent): """ @@ -47,19 +32,12 @@ class SSLCPS(SSLSegAgent): def __init__(self, config, stage = 'train'): super(SSLCPS, self).__init__(config, stage) - def create_network(self): - if(self.net is None): - self.net = BiNet(self.config['network']) - if(self.tensor_type == 'float'): - self.net.float() - else: - self.net.double() - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] - iter_max = self.config['training']['iter_max'] + iter_max = self.config['training']['iter_max'] + mixup_prob = self.config['training'].get('mixup_probability', 0.0) rampup_start = ssl_cfg.get('rampup_start', 0) rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 @@ -83,13 +61,27 @@ def training(self): x0 = self.convert_tensor_type(data_lab['image']) y0 = self.convert_tensor_type(data_lab['label_prob']) x1 = self.convert_tensor_type(data_unlab['image']) + + # for debug + # for i in range(x0.shape[0]): + # image_i = x0[i][0] + # label_i = np.argmax(y0[i], axis = 0) + # # pixw_i = pix_w[i][0] + # print(image_i.shape, label_i.shape) + # image_name = "temp/image_{0:}_{1:}.nii.gz".format(it, i) + # label_name = 
"temp/label_{0:}_{1:}.nii.gz".format(it, i) + # save_nd_array_as_image(image_i, image_name, reference_name = None) + # save_nd_array_as_image(label_i, label_name, reference_name = None) + # continue + if(mixup_prob > 0 and random() < mixup_prob): + x0, y0 = mixup(x0, y0) inputs = torch.cat([x0, x1], dim = 0) inputs, y0 = inputs.to(self.device), y0.to(self.device) # zero the parameter gradients self.optimizer.zero_grad() - outputs1, outputs2 = self.net(inputs) + outputs1, outputs2 = self.net(inputs) outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) diff --git a/pymic/net_run/semi_sup/ssl_mcnet.py b/pymic/net_run/semi_sup/ssl_mcnet.py new file mode 100644 index 0000000..66e1034 --- /dev/null +++ b/pymic/net_run/semi_sup/ssl_mcnet.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.net_run.semi_sup import SSLSegAgent +from pymic.util.ramps import get_rampup_ratio + +def sharpening(P, T = 0.1): + T = 1.0/T + P_sharpen = P**T / (P**T + (1-P)**T) + return P_sharpen + +class SSLMCNet(SSLSegAgent): + """ + Mutual Consistency Learning for semi-supervised segmentation. It requires a network + with multiple decoders for learning, such as `pymic.net.net2d.unet2d_mcnet.MCNet2D`. + + * Reference: Yicheng Wu, Zongyuan Ge et al. Mutual consistency learning for + semi-supervised medical image segmentation. + `MIA 2022. `_ + + The original code is at: https://github.com/ycwu1997/MC-Net + + :param config: (dict) A dictionary containing the configuration. + :param stage: (str) One of the stage in `train` (default), `inference` or `test`. + + .. note:: + In the configuration dictionary, in addition to the four sections (`dataset`, + `network`, `training` and `inference`) used in fully supervised learning, an + extra section `semi_supervised_learning` is needed. See :doc:`usage.ssl` for details. 
+ """ + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) + temperature = ssl_cfg.get('temperature', 0.1) + unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") + train_loss = 0 + train_loss_sup = 0 + train_loss_reg = 0 + train_dice_list = [] + self.net.train() + + for it in range(iter_valid): + try: + data_lab = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data_lab = next(self.trainIter) + try: + data_unlab = next(self.trainIter_unlab) + except StopIteration: + self.trainIter_unlab = iter(self.train_loader_unlab) + data_unlab = next(self.trainIter_unlab) + + # get the inputs + x0 = self.convert_tensor_type(data_lab['image']) + y0 = self.convert_tensor_type(data_lab['label_prob']) + x1 = self.convert_tensor_type(data_unlab['image']) + inputs = torch.cat([x0, x1], dim = 0) + inputs, y0 = inputs.to(self.device), y0.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward pass to obtain multiple predictions + outputs = self.net(inputs) + num_outputs = len(outputs) + n0 = list(x0.shape)[0] + p0 = F.softmax(outputs[0], dim=1)[:n0] + # for probability prediction and pseudo respectively + p_ori = torch.zeros((num_outputs,) + outputs[0].shape) + y_psu = torch.zeros((num_outputs,) + outputs[0].shape) + + # get supervised loss + loss_sup = 0 + for idx in range(num_outputs): + p0i = outputs[idx][:n0] + loss_sup += self.get_loss_value(data_lab, p0i, y0) + + # get pseudo labels + p_i = F.softmax(outputs[idx], dim=1) + p_ori[idx] = p_i + y_psu[idx] = sharpening(p_i, temperature) + + # get regularization loss + loss_reg = 0.0 + for i in range(num_outputs): + for j in range(num_outputs): + if (i!=j): + loss_reg += F.mse_loss(p_ori[i], y_psu[j], reduction='mean') + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg + + loss.backward() + self.optimizer.step() + + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(p0, tuple) or isinstance(p0, list)): + p0 = p0[0] + p0_argmax = torch.argmax(p0, dim = 1, keepdim = True) + p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type) + p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) + dice_list = get_classwise_dice(p0_soft, y0) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice[1:].mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, + 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers diff --git a/pymic/net_run/train.py b/pymic/net_run/train.py index 426b620..50a5fb7 100644 --- a/pymic/net_run/train.py +++ b/pymic/net_run/train.py @@ -9,16 +9,21 @@ from pymic.util.parse_config import * from pymic.net_run.agent_cls import ClassificationAgent 
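Two small illustrations of the `SSLMCNet` agent defined in `ssl_mcnet.py` above. First, a sketch of the extra configuration section, using the default values read in its `training()` method (`rampup_end` defaults to `iter_max`, so the value shown here is only an example):

```
[semi_supervised_learning]
method_name  = MCNet
regularize_w = 0.1
rampup_start = 0
rampup_end   = 20000
temperature  = 0.1
```

Second, a stand-alone reproduction of the `sharpening` helper, showing how the default temperature pushes pseudo-label probabilities towards 0 or 1:

```
import numpy as np

def sharpening(P, T = 0.1):
    # same formula as in ssl_mcnet.py: P**(1/T) / (P**(1/T) + (1-P)**(1/T))
    T = 1.0 / T
    return P**T / (P**T + (1 - P)**T)

p = np.array([0.3, 0.5, 0.7, 0.9])
print(sharpening(p))   # approx. [0.0002, 0.5, 0.9998, 1.0000]
```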
from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.agent_rec import ReconstructionAgent from pymic.net_run.semi_sup import SSLMethodDict from pymic.net_run.weak_sup import WSLMethodDict +from pymic.net_run.self_sup import SelfSupMethodDict from pymic.net_run.noisy_label import NLLMethodDict -from pymic.net_run.self_sup import SelfSLSegAgent +# from pymic.net_run.self_sup import SelfSLSegAgent def get_seg_rec_agent(config, sup_type): assert(sup_type in ['fully_sup', 'semi_sup', 'self_sup', 'weak_sup', 'noisy_label']) if(sup_type == 'fully_sup'): logging.info("\n********** Fully Supervised Learning **********\n") - agent = SegmentationAgent(config, 'train') + if config['dataset']['task_type'] == TaskType.SEGMENTATION: + agent = SegmentationAgent(config, 'train') + else: + agent = ReconstructionAgent(config, 'train') elif(sup_type == 'semi_sup'): logging.info("\n********** Semi Supervised Learning **********\n") method = config['semi_supervised_learning']['method_name'] @@ -34,28 +39,7 @@ def get_seg_rec_agent(config, sup_type): elif(sup_type == 'self_sup'): logging.info("\n********** Self Supervised Learning **********\n") method = config['self_supervised_learning']['method_name'] - if(method == "custom"): - pass - elif(method == "model_genesis"): - transforms = ['RandomFlip', 'LocalShuffling', 'NonLinearTransform', 'InOutPainting'] - genesis_cfg = { - 'randomflip_flip_depth': True, - 'randomflip_flip_height': True, - 'randomflip_flip_width': True, - 'localshuffling_probability': 0.5, - 'nonLineartransform_probability': 0.9, - 'inoutpainting_probability': 0.9, - 'inpainting_probability': 0.2 - } - config['dataset']['train_transform'].extend(transforms) - # config['dataset']['valid_transform'].extend(transforms) - config['dataset'].update(genesis_cfg) - logging_config(config['dataset']) - else: - raise ValueError("The specified method {0:} is not implemented. 
".format(method) + \ - "Consider to set `self_sl_method = custom` and use customized" + \ - " transforms for self-supervised learning.") - agent = SelfSLSegAgent(config, 'train') + agent = SelfSupMethodDict[method](config, 'train') else: raise ValueError("undefined supervision type: {0:}".format(sup_type)) return agent diff --git a/pymic/test/test_assd.py b/pymic/test/test_assd.py index 35c1804..6b2732a 100644 --- a/pymic/test/test_assd.py +++ b/pymic/test/test_assd.py @@ -18,7 +18,8 @@ def test_assd_2d(): plt.show() def test_assd_3d(): - img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + # img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + img_name = "/home/disk4t/data/heart/ACDC/preprocess/patient001_frame12_gt.nii.gz" img_obj = sitk.ReadImage(img_name) spacing = img_obj.GetSpacing() spacing = spacing[::-1] diff --git a/pymic/test/test_net2d.py b/pymic/test/test_net2d.py new file mode 100644 index 0000000..aafaf20 --- /dev/null +++ b/pymic/test/test_net2d.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import numpy as np +from pymic.net.net2d.unet2d import UNet2D +from pymic.net.net2d.unet2d_scse import UNet2D_ScSE + +def test_unet2d(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) + +def test_unet2d_scse(): + params = {'in_chns':4, + 'feature_chns':[16, 32, 64, 128, 256], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D_ScSE(params) + Net = Net.double() + + x = np.random.rand(4, 4, 10, 256, 256) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + if params['multiscale_pred']: + for y in out: + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + else: + print(out.shape) + +if __name__ == "__main__": + # test_unet2d() + test_unet2d_scse() \ No newline at end of file diff --git a/pymic/test/test_net3d.py b/pymic/test/test_net3d.py new file mode 100644 index 0000000..180dcff --- /dev/null +++ b/pymic/test/test_net3d.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import numpy as np +from pymic.net.net3d.unet3d import UNet3D +from pymic.net.net3d.unet3d_scse import UNet3D_ScSE +from pymic.net.net3d.unet2d5 import UNet2D5 + +def test_unet3d(): + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[2, 8, 32, 64], + 'dropout' : [0, 0, 0, 0.5], + 'up_mode': 2, + 'multiscale_pred': False} + Net = UNet3D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + y = y.detach().numpy() + print(y.shape) + + params = {'in_chns':4, + 'class_num': 2, + 'feature_chns':[2, 8, 32, 64, 128], + 'dropout' : [0, 0, 0, 0.4, 0.5], + 'up_mode': 3, + 'multiscale_pred': True} + Net = UNet3D(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = 
y.detach().numpy() + print(y.shape) + +def test_unet3d_scse(): + params = {'in_chns':4, + 'feature_chns':[2, 8, 32, 48, 64], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 2} + Net = UNet3D_ScSE(params) + Net = Net.double() + + x = np.random.rand(4, 4, 96, 96, 96) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + y = Net(xt) + print(len(y.size())) + y = y.detach().numpy() + print(y.shape) + +def test_unet2d5(): + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'conv_dims': [2, 2, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 2, + 'multiscale_pred': True} + Net = UNet2D5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 32, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + + params = {'in_chns':4, + 'feature_chns':[8, 16, 32, 64, 128], + 'conv_dims': [2, 3, 3, 3, 3], + 'dropout': [0, 0, 0.3, 0.4, 0.5], + 'class_num': 2, + 'up_mode': 0, + 'multiscale_pred': True} + Net = UNet2D5(params) + Net = Net.double() + + x = np.random.rand(4, 4, 64, 128, 128) + xt = torch.from_numpy(x) + xt = torch.tensor(xt) + + out = Net(xt) + for y in out: + y = y.detach().numpy() + print(y.shape) + +if __name__ == "__main__": + # test_unet3d() + # test_unet3d_scse() + test_unet2d5() + + \ No newline at end of file diff --git a/pymic/transform/affine.py b/pymic/transform/affine.py new file mode 100644 index 0000000..552516f --- /dev/null +++ b/pymic/transform/affine.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from skimage import transform +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + +class Affine(AbstractTransform): + """ + Apply Affine Transform to an ND volume in the x-y plane. + Input shape should be [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `Affine_scale_range`: (list or tuple) The range for scaling, e.g., (0.5, 2.0) + :param `Affine_shear_range`: (list or tuple) The range for shearing angle, e.g., (0, 30) + :param `Affine_rotate_range`: (list or tuple) The range for rotation, e.g., (-45, 45) + :param `Affine_output_size`: (None, list or tuple of length 2) The output size after affine transformation. + For 3D volumes, as we only apply affine transformation in x-y plane, the output slice + number will be the same as the input slice number, so only the output height and width + need to be given here, e.g., (H, W). By default (`None`), the output size will be the + same as the input size. 
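    An illustrative `params` dictionary for this transform, reusing the example ranges above
    (keys are looked up in lower case by `__init__`; other fields required by the runtime,
    such as the task type, are omitted here):

    ```
    params = {
        'affine_scale_range' : (0.5, 2.0),
        'affine_shear_range' : (0, 30),
        'affine_rotate_range': (-45, 45),
        'affine_output_size' : None,       # keep the input in-plane size
    }
    ```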
+ """ + def __init__(self, params): + super(Affine, self).__init__(params) + self.scale_range = params['Affine_scale_range'.lower()] + self.shear_range = params['Affine_shear_range'.lower()] + self.rotat_range = params['Affine_rotate_range'.lower()] + self.output_shape= params.get('Affine_output_size'.lower(), None) + self.inverse = params.get('Affine_inverse'.lower(), True) + + def _get_affine_param(self, sample, output_shape): + """ + output_shape should only has two dimensions, e.g., (H, W) + """ + input_shape = sample['image'].shape + input_dim = len(input_shape) - 1 + assert(len(output_shape) >=2) + + in_y, in_x = input_shape[-2:] + out_y, out_x = output_shape[-2:] + points = [[0, out_y], + [0, 0], + [out_x, 0], + [out_x, out_y]] + + sx = random.random() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] + sy = random.random() * (self.scale_range[1] - self.scale_range[0]) + self.scale_range[0] + shx = (random.random() * (self.shear_range[1] - self.shear_range[0]) + self.shear_range[0]) * 3.14159/180 + shy = (random.random() * (self.shear_range[1] - self.shear_range[0]) + self.shear_range[0]) * 3.14159/180 + rot = (random.random() * (self.rotat_range[1] - self.rotat_range[0]) + self.rotat_range[0]) * 3.14159/180 + # get affine transform parameters + new_points = [] + for p in points: + x = sx * p[0] * (math.cos(rot) + math.tan(shy) * math.sin(rot)) - \ + sy * p[1] * (math.tan(shx) * math.cos(rot) + math.sin(rot)) + y = sx * p[0] * (math.sin(rot) - math.tan(shy) * math.cos(rot)) - \ + sy * p[1] * (math.tan(shx) * math.sin(rot) - math.cos(rot)) + new_points.append([x,y]) + bb_min = np.array(new_points).min(axis = 0) + bb_max = np.array(new_points).max(axis = 0) + bbx, bby = int(bb_max[0] - bb_min[0]), int(bb_max[1] - bb_min[1]) + # transform the points to the image coordinate + margin_x = in_x - bbx + margin_y = in_y - bby + p0x = random.random() * margin_x if margin_x > 0 else margin_x / 2 + p0y = random.random() * margin_y if margin_y > 0 else margin_y / 2 + dst = [[new_points[i][0] - bb_min[0] + p0x, new_points[i][1] - bb_min[1] + p0y] \ + for i in range(3)] + + tform = transform.AffineTransform() + tform.estimate(np.array(points[:3]), np.array(dst)) + # to do: need to find a solution to save the affine transform matrix + # Use the matplotlib.transforms.Affine2D function to generate transform matrices, + # and the scipy.ndimage.warp function to warp images using the transform matrices. + # The skimage AffineTransform shear functionality is weird, + # and the scipy affine_transform function for warping images swaps the X and Y axes. 
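        # Summary of the computation above: the four corners of the output window are mapped
        # through the sampled scale/shear/rotation, the bounding box of the mapped corners
        # is measured, and a random in-plane offset (p0x, p0y) places that box inside the
        # input image when it fits; the first three corner pairs are then passed to
        # AffineTransform.estimate() to obtain the transform returned as `tform`.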
+ # sample['Affine_Param'] = json.dumps((input_shape, tform["matrix"])) + return sample, tform + + def _apply_affine_to_ND_volume(self, image, output_shape, tform, order = 3): + """ + output_shape should only has two dimensions, e.g., (H, W) + """ + dim = len(image.shape) - 1 + if(dim == 2): + C, H, W = image.shape + output = np.zeros([C] + output_shape) + for c in range(C): + output[c] = ndimage.affine_transform(image[c], tform, + output_shape = output_shape, mode='mirror', order = order) + elif(dim == 3): + C, D, H, W = image.shape + output = np.zeros([C, D] + output_shape) + for c in range(C): + for d in range(D): + output[c,d] = ndimage.affine_transform(image[c,d], tform, + output_shape = output_shape, mode='mirror', order = order) + return output + + def __call__(self, sample): + image = sample['image'] + input_shape = sample['image'].shape + output_shape= input_shape if self.output_shape is None else self.output_shape + aff_out_shape = output_shape[-2:] + sample, tform = self._get_affine_param(sample, aff_out_shape) + image_t = self._apply_affine_to_ND_volume(image, aff_out_shape, tform) + sample['image'] = image_t + + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + label = self._apply_affine_to_ND_volume(label, aff_out_shape, tform, order = 0) + sample['label'] = label + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + weight = self._apply_affine_to_ND_volume(weight, aff_out_shape, tform) + sample['pixel_weight'] = weight + return sample + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['Affine_Param'], list) or \ + isinstance(sample['Affine_Param'], tuple)): + params = json.loads(sample['Affine_Param'][0]) + else: + params = json.loads(sample['Affine_Param']) + return params + + # def inverse_transform_for_prediction(self, sample): + # params = self._get_param_for_inverse_transform(sample) + # origin_shape = params[0] + # tform = params[1] + + # predict = sample['predict'] + # if(isinstance(predict, tuple) or isinstance(predict, list)): + # output_predict = [] + # for predict_i in predict: + # aff_out_shape = origin_shape[-2:] + # output_predict_i = self._apply_affine_to_ND_volume(predict_i, + # aff_out_shape, tform.inverse) + # output_predict.append(output_predict_i) + # else: + # aff_out_shape = origin_shape[-2:] + # output_predict = self._apply_affine_to_ND_volume(predict, aff_out_shape, tform.inverse) + + # sample['predict'] = output_predict + # return sample diff --git a/pymic/transform/crop.py b/pymic/transform/crop.py index b8130f9..b821bb2 100644 --- a/pymic/transform/crop.py +++ b/pymic/transform/crop.py @@ -113,16 +113,20 @@ class CropWithBoundingBox(CenterCrop): :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index along each spatial axis. If None, calculate the start index automatically - so that the cropped region is centered at the non-zero region. + so that the cropped region is centered at the mask region defined by the threshold. :param `CropWithBoundingBox_output_size`: (None or tuple/list): Desired spatial output size. - If None, set it as the size of bounding box of non-zero region. + If None, set it as the size of bounding box of the mask region defined by the threshold. + :param `CropWithBoundingBox_threshold`: (None or float): + Threshold for obtaining a mask. This is used only when + `CropWithBoundingBox_start` is None. 
Default is 1.0 :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `True`. """ def __init__(self, params): self.start = params['CropWithBoundingBox_start'.lower()] self.output_size = params['CropWithBoundingBox_output_size'.lower()] + self.threshold = params.get('CropWithBoundingBox_threshold'.lower(), 1.0) self.inverse = params.get('CropWithBoundingBox_inverse'.lower(), True) self.task = params['task'] @@ -130,8 +134,9 @@ def _get_crop_param(self, sample): image = sample['image'] input_shape = sample['image'].shape input_dim = len(input_shape) - 1 - bb_min, bb_max = get_ND_bounding_box(image) - bb_min, bb_max = bb_min[1:], bb_max[1:] + if(self.start is None or self.output_size is None): + bb_min, bb_max = get_ND_bounding_box(image > self.threshold) + bb_min, bb_max = bb_min[1:], bb_max[1:] if(self.start is None): if(self.output_size is None): crop_min, crop_max = bb_min, bb_max @@ -153,7 +158,6 @@ def _get_crop_param(self, sample): crop_min = [0] + crop_min crop_max = list(input_shape[0:1]) + crop_max sample['CropWithBoundingBox_Param'] = json.dumps((input_shape, crop_min, crop_max)) - print("for crop", crop_min, crop_max) return sample, crop_min, crop_max def _get_param_for_inverse_transform(self, sample): @@ -213,7 +217,9 @@ class RandomCrop(CenterCrop): :param `RandomCrop_output_size`: (list/tuple) Desired output size [D, H, W] or [H, W]. The output channel is the same as the input channel. - If D is None for 3D images, the z-axis is not cropped. + If `None` is set for a certain axis, that axis will not be cropped. For example, + for 3D vlumes, (None, H, W) means only crop in 2D, and (D, None, None) means only + crop along the z axis. :param `RandomCrop_foreground_focus`: (optional, bool) If true, allow crop around the foreground. Default is False. :param `RandomCrop_foreground_ratio`: (optional, float) @@ -243,19 +249,26 @@ def _get_crop_param(self, sample): input_shape = image.shape[1:] input_dim = len(input_shape) assert(input_dim == len(self.output_size)) - - crop_margin = [input_shape[i] - self.output_size[i] for i in range(input_dim)] + + output_size = [item for item in self.output_size] + # print("crop input and output size", input_shape, output_size) + for i in range(input_dim): + if(output_size[i] is None): + output_size[i] = input_shape[i] + # print(output_size) + crop_margin = [input_shape[i] - output_size[i] for i in range(input_dim)] crop_min = [0 if item == 0 else random.randint(0, item) for item in crop_margin] - crop_max = [crop_min[i] + self.output_size[i] for i in range(input_dim)] + crop_max = [crop_min[i] + output_size[i] for i in range(input_dim)] - if(self.fg_focus and random.random() < self.fg_ratio): + label_exist = False if ('label' not in sample or sample['label']) is None else True + if(label_exist and self.fg_focus and random.random() < self.fg_ratio): label = sample['label'][0] if(self.mask_label is None): mask_label = np.unique(label)[1:] else: mask_label = self.mask_label random_label = random.choice(mask_label) - crop_min, crop_max = get_random_box_from_mask(label == random_label, self.output_size) + crop_min, crop_max = get_random_box_from_mask(label == random_label, output_size, mode = 1) crop_min = [0] + crop_min crop_max = [chns] + crop_max @@ -279,30 +292,45 @@ class RandomResizedCrop(CenterCrop): :param `RandomResizedCrop_output_size`: (list/tuple) Desired output size [D, H, W]. The output channel is the same as the input channel. 
- :param `RandomResizedCrop_scale_range`: (list/tuple) Range of scale, e.g. (0.08, 1.0). + :param `RandomResizedCrop_scale_lower_bound`: (list/tuple) Lower bound of the range of scale + for each dimension. e.g. (1.0, 0.5, 0.5). + param `RandomResizedCrop_scale_upper_bound`: (list/tuple) Upper bound of the range of scale + for each dimension. e.g. (1.0, 2.0, 2.0). :param `RandomResizedCrop_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. Currently, the inverse transform is not supported, and this transform is assumed to be used only during training stage. """ def __init__(self, params): self.output_size = params['RandomResizedCrop_output_size'.lower()] - self.scale = params['RandomResizedCrop_scale_range'.lower()] + self.scale_lower = params['RandomResizedCrop_resize_lower_bound'.lower()] + self.scale_upper = params['RandomResizedCrop_resize_upper_bound'.lower()] + self.prob = params.get('RandomResizedCrop_resize_prob'.lower(), 0.5) + self.fg_ratio = params.get('RandomResizedCrop_foreground_ratio'.lower(), 0.0) + self.mask_label = params.get('RandomResizedCrop_mask_label'.lower(), None) self.inverse = params.get('RandomResizedCrop_inverse'.lower(), False) self.task = params['Task'.lower()] assert isinstance(self.output_size, (list, tuple)) - assert isinstance(self.scale, (list, tuple)) + assert isinstance(self.scale_lower, (list, tuple)) + assert isinstance(self.scale_upper, (list, tuple)) def __call__(self, sample): image = sample['image'] channel, input_size = image.shape[0], image.shape[1:] input_dim = len(input_size) assert(input_dim == len(self.output_size)) - scale = self.scale[0] + random.random()*(self.scale[1] - self.scale[0]) - crop_size = [int(self.output_size[i] * scale) for i in range(input_dim)] + + # get the resized crop size + resize = random.random() < self.prob + if(resize): + scale = [self.scale_lower[i] + (self.scale_upper[i] - self.scale_lower[i]) * random.random() \ + for i in range(input_dim)] + crop_size = [int(self.output_size[i] * scale[i]) for i in range(input_dim)] + else: + crop_size = self.output_size + crop_margin = [input_size[i] - crop_size[i] for i in range(input_dim)] - pad_image = False - if(min(crop_margin) < 0): - pad_image = True + pad_image = min(crop_margin) < 0 + if(pad_image): # pad the image if necessary pad_size = [max(0, -crop_margin[i]) for i in range(input_dim)] pad_lower = [int(pad_size[i] / 2) for i in range(input_dim)] pad_upper = [pad_size[i] - pad_lower[i] for i in range(input_dim)] @@ -310,16 +338,29 @@ def __call__(self, sample): pad = tuple([(0, 0)] + pad) image = np.pad(image, pad, 'reflect') crop_margin = [max(0, crop_margin[i]) for i in range(input_dim)] - - crop_min = [random.randint(0, item) for item in crop_margin] - crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] + # ge the bounding box for crop + if(random.random() < self.fg_ratio): + label = sample['label'] + if(pad_image): + label = np.pad(label, pad, 'reflect') + label = label[0] + if(self.mask_label is None): + mask_label = np.unique(label)[1:] + else: + mask_label = self.mask_label + random_label = random.choice(mask_label) + crop_min, crop_max = get_random_box_from_mask(label == random_label, crop_size, mode = 1) + else: + crop_min = [random.randint(0, item) for item in crop_margin] + crop_max = [crop_min[i] + crop_size[i] for i in range(input_dim)] crop_min = [0] + crop_min crop_max = [channel] + crop_max image_t = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) - scale = [(self.output_size[i] + 
0.0)/crop_size[i] for i in range(input_dim)] - scale = [1.0] + scale - image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) + if(resize): + scale = [(self.output_size[i] + 0.0)/crop_size[i] for i in range(input_dim)] + scale = [1.0] + scale + image_t = ndimage.interpolation.zoom(image_t, scale, order = 1) sample['image'] = image_t if('label' in sample and \ @@ -329,8 +370,9 @@ def __call__(self, sample): label = np.pad(label, pad, 'reflect') crop_max[0] = label.shape[0] label = crop_ND_volume_with_bounding_box(label, crop_min, crop_max) - order = 0 if(self.task == TaskType.SEGMENTATION) else 1 - label = ndimage.interpolation.zoom(label, scale, order = order) + if(resize): + order = 0 if(self.task == TaskType.SEGMENTATION) else 1 + label = ndimage.interpolation.zoom(label, scale, order = order) sample['label'] = label if('pixel_weight' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): @@ -339,6 +381,95 @@ def __call__(self, sample): weight = np.pad(weight, pad, 'reflect') crop_max[0] = weight.shape[0] weight = crop_ND_volume_with_bounding_box(weight, crop_min, crop_max) - weight = ndimage.interpolation.zoom(weight, scale, order = 1) + if(resize): + weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight - return sample \ No newline at end of file + return sample + +class RandomSlice(AbstractTransform): + """Randomly selecting N slices from a volume + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomSlice_output_size`: (int) Desired number of slice for output. + :param `RandomSlice_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.output_size = params['RandomSlice_output_size'.lower()] + self.shuffle = params.get('RandomSlice_shuffle'.lower(), False) + self.inverse = params.get('RandomSlice_inverse'.lower(), False) + self.task = params['Task'.lower()] + + def __call__(self, sample): + image = sample['image'] + D = image.shape[1] + assert( D >= self.output_size) + slice_idx = list(range(D)) + if(self.shuffle): + random.shuffle(slice_idx) + slice_idx = slice_idx[:self.output_size] + else: + d0 = random.randint(0, D - self.output_size) + d1 = d0 + self.output_size + slice_idx = slice_idx[d0:d1] + sample['image'] = image[:, slice_idx, :, :] + + if('label' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + label = sample['label'] + sample['label'] = label[:, slice_idx, :, :] + + if('pixel_weight' in sample and \ + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + weight = sample['pixel_weight'] + sample['pixel_weight'] = weight[:, slice_idx, :, :] + + return sample + +class CropHumanRegionFromCT(CenterCrop): + """ + Crop the human region from a CT volume. + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `CropWithBoundingBox_start`: (None, or list/tuple) The start index + along each spatial axis. If None, calculate the start index automatically + so that the cropped region is centered at the mask region defined by the threshold. + :param `CropWithBoundingBox_output_size`: (None or tuple/list): + Desired spatial output size. + If None, set it as the size of bounding box of the mask region defined by the threshold. + :param `CropWithBoundingBox_threshold`: (None or float): + Threshold for obtaining a mask. This is used only when + `CropWithBoundingBox_start` is None. 
Default is 1.0 + :param `CropWithBoundingBox_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. + """ + def __init__(self, params): + self.threshold_i = params.get('CropHumanRegionFromCT_intensity_threshold'.lower(), -600) + self.threshold_z = params.get('CropHumanRegionFromCT_zaxis_threshold'.lower(), 0.5) + self.inverse = params.get('CropHumanRegionFromCT_inverse'.lower(), True) + self.task = params['task'] + + def _get_crop_param(self, sample): + image = sample['image'] + input_shape = image.shape + mask = np.asarray(image[0] > self.threshold_i) + mask2d = np.mean(mask, axis = 0) > self.threshold_z + se = np.ones([3,3]) + mask2d = ndimage.binary_opening(mask2d, se, iterations = 2) + mask2d = get_largest_k_components(mask2d, 1) + bbmin, bbmax = get_ND_bounding_box(mask2d, margin = [0, 0]) + crop_min = [0, 0] + bbmin + crop_max = list(input_shape[:2]) + bbmax + sample['CropHumanRegionFromCT_Param'] = json.dumps((input_shape, crop_min, crop_max)) + return sample, crop_min, crop_max + + def _get_param_for_inverse_transform(self, sample): + if(isinstance(sample['CropHumanRegionFromCT_Param'], list) or \ + isinstance(sample['CropHumanRegionFromCT_Param'], tuple)): + params = json.loads(sample['CropHumanRegionFromCT_Param'][0]) + else: + params = json.loads(sample['CropHumanRegionFromCT_Param']) + return params \ No newline at end of file diff --git a/pymic/transform/extract_channel.py b/pymic/transform/extract_channel.py new file mode 100644 index 0000000..c4974be --- /dev/null +++ b/pymic/transform/extract_channel.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic import TaskType +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + + +class ExtractChannel(AbstractTransform): + """ Random flip the image. The shape is [C, D, H, W] or [C, H, W]. + + The arguments should be written in the `params` dictionary, and it has the + following fields: + + :param `RandomFlip_flip_depth`: (bool) + Random flip along depth axis or not, only used for 3D images. + :param `RandomFlip_flip_height`: (bool) Random flip along height axis or not. + :param `RandomFlip_flip_width`: (bool) Random flip along width axis or not. + :param `RandomFlip_inverse`: (optional, bool) Is inverse transform needed for inference. + Default is `True`. 
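    For this transform, the fields actually read in `__init__` below are:

    :param `ExtractChannel_channels`: (list/tuple of int) Indices of the channels to extract
        from the input image.
    :param `ExtractChannel_inverse`: (optional, bool) Is inverse transform needed for inference.
        Default is `False`.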
+ """ + def __init__(self, params): + super(ExtractChannel, self).__init__(params) + self.channels = params['ExtractChannel_channels'.lower()] + self.inverse = params.get('ExtractChannel_inverse'.lower(), False) + + def __call__(self, sample): + image = sample['image'] + image_extract = [] + for i in self.channels: + image_extract.append(image[i]) + sample['image'] = np.asarray(image_extract) + return sample diff --git a/pymic/transform/flip.py b/pymic/transform/flip.py index 24cafb4..6ea017c 100644 --- a/pymic/transform/flip.py +++ b/pymic/transform/flip.py @@ -6,7 +6,6 @@ import math import random import numpy as np -from scipy import ndimage from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * @@ -54,7 +53,7 @@ def __call__(self, sample): image_t = np.flip(image, flip_axis).copy() sample['image'] = image_t if('label' in sample and \ - self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): sample['label'] = np.flip(sample['label'] , flip_axis).copy() if('pixel_weight' in sample and \ self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py index 1a13190..2b19ebc 100644 --- a/pymic/transform/intensity.py +++ b/pymic/transform/intensity.py @@ -37,7 +37,7 @@ def bezier_curve(points, nTimes=1000): t = np.linspace(0.0, 1.0, nTimes) - polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints) ]) + polynomial_array = np.array([ bernstein_poly(i, nPoints-1, t) for i in range(0, nPoints)]) xvals = np.dot(xPoints, polynomial_array) yvals = np.dot(yPoints, polynomial_array) @@ -62,6 +62,7 @@ def __init__(self, params): self.channels = params['IntensityClip_channels'.lower()] self.lower = params.get('IntensityClip_lower'.lower(), None) self.upper = params.get('IntensityClip_upper'.lower(), None) + self.perct = params.get('IntensityClip_percentile_mode'.lower(), False) self.inverse = params.get('IntensityClip_inverse'.lower(), False) def __call__(self, sample): @@ -72,8 +73,12 @@ def __call__(self, sample): lower_c, upper_c = lower[chn], upper[chn] if(lower_c is None): lower_c = np.percentile(image[chn], 0.05) + elif(self.perct): + lower_c = np.percentile(image[chn], lower_c) if(upper_c is None): - upper_c = np.percentile(image[chn, 99.95]) + upper_c = np.percentile(image[chn], 99.95) + elif(self.perct): + upper_c = np.percentile(image[chn], upper_c) image[chn] = np.clip(image[chn], lower_c, upper_c) sample['image'] = image return sample @@ -95,23 +100,28 @@ class GammaCorrection(AbstractTransform): """ def __init__(self, params): super(GammaCorrection, self).__init__(params) - self.channels = params['GammaCorrection_channels'.lower()] - self.gamma_min = params['GammaCorrection_gamma_min'.lower()] - self.gamma_max = params['GammaCorrection_gamma_max'.lower()] + self.channels = params.get('GammaCorrection_channels'.lower(), None) + self.gamma_min = params.get('GammaCorrection_gamma_min'.lower(), 0.7) + self.gamma_max = params.get('GammaCorrection_gamma_max'.lower(), 1.5) + self.flip_prob = params.get('GammaCorrection_intensity_flip_probability'.lower(), 0.0) self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) self.inverse = params.get('GammaCorrection_inverse'.lower(), False) def __call__(self, sample): - if(np.random.uniform() > self.prob): - return sample image= sample['image'] + if(self.channels is None): + self.channels = 
range(image.shape[0]) for chn in self.channels: + if(np.random.uniform() > self.prob): + continue gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min img_c = image[chn] v_min = img_c.min() v_max = img_c.max() if(v_min < v_max): img_c = (img_c - v_min)/(v_max - v_min) + if(np.random.uniform() < self.flip_prob): + img_c = 1.0 - img_c img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min image[chn] = img_c @@ -135,20 +145,21 @@ class GaussianNoise(AbstractTransform): """ def __init__(self, params): super(GaussianNoise, self).__init__(params) - self.channels = params['GaussianNoise_channels'.lower()] + self.channels = params.get('GaussianNoise_channels'.lower(), None) self.mean = params['GaussianNoise_mean'.lower()] self.std = params['GaussianNoise_std'.lower()] self.prob = params.get('GaussianNoise_probability'.lower(), 0.5) self.inverse = params.get('GaussianNoise_inverse'.lower(), False) def __call__(self, sample): - if(np.random.uniform() > self.prob): - return sample - image= sample['image'] + image = sample['image'] + if(self.channels is None): + self.channels = range(image.shape[0]) for chn in self.channels: - img_c = image[chn] - noise = np.random.normal(self.mean, self.std, img_c.shape) - image[chn] = img_c + noise + if(np.random.uniform() < self.prob): + img_c = image[chn] + noise = np.random.normal(self.mean, self.std, img_c.shape) + image[chn] = img_c + noise sample['image'] = image return sample @@ -171,21 +182,55 @@ def __call__(self, sample): class NonLinearTransform(AbstractTransform): def __init__(self, params): super(NonLinearTransform, self).__init__(params) - self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + self.channels = params.get('NonLinearTransform_channels'.lower(), None) self.prob = params.get('NonLinearTransform_probability'.lower(), 0.5) + self.inverse = params.get('NonLinearTransform_inverse'.lower(), False) + self.block_range = params.get('NonLinearTransform_block_range'.lower(), None) + self.block_size = params.get('NonLinearTransform_block_size'.lower(), [8, 16, 16]) + - def __call__(self, sample): - if(random.random() > self.prob): - return sample - - image= sample['image'] + def __apply_nonlinear_transform(self, img): + """ + the input img should be normlized to [0, 1]""" points = [[0, 0], [random.random(), random.random()], [random.random(), random.random()], [1, 1]] - xvals, yvals = bezier_curve(points, nTimes=100000) - if random.random() < 0.5: # Half change to get flip + xvals, yvals = bezier_curve(points, nTimes=10000) + if random.random() < 0.5: # Half chance to get flip xvals = np.sort(xvals) else: xvals, yvals = np.sort(xvals), np.sort(yvals) - image = np.interp(image, xvals, yvals) + + img = np.interp(img, xvals, yvals) + return img + + def __call__(self, sample): + if(random.random() > self.prob): + return sample + + image = sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + channels = self.channels if self.channels is not None else range(image.shape[0]) + for chn in channels: + # normalize the image intensity to [0, 1] before the non-linear tranform + img_c = image[chn] + v_min, v_max = img_c.min(), img_c.max() + if(v_min < v_max): + img_c = (img_c - v_min)/(v_max - v_min) + if(self.block_range is None): # apply non-linear transform to the entire image + img_c = self.__apply_nonlinear_transform(img_c) + else: # non-linear transform to random blocks + img_c_sr = copy.deepcopy(img_c) + for n in range(self.block_range[0], self.block_range[1]): + coord_min = 
[random.randint(0, img_shape[1+i] - self.block_size[i]) \ + for i in range(img_dim)] + window = img_c_sr[coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] + img_c[coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = \ + self.__apply_nonlinear_transform(window) + image[chn] = img_c * (v_max - v_min) + v_min sample['image'] = image return sample @@ -197,9 +242,8 @@ def __init__(self, params): super(LocalShuffling, self).__init__(params) self.inverse = params.get('LocalShuffling_inverse'.lower(), False) self.prob = params.get('LocalShuffling_probability'.lower(), 0.5) - self.block_range = params.get('LocalShuffling_block_range'.lower(), (5000, 10000)) - self.block_size_min = params.get('LocalShuffling_block_size_min'.lower(), None) - self.block_size_max = params.get('LocalShuffling_block_size_max'.lower(), None) + self.block_range = params.get('LocalShuffling_block_range'.lower(), [40, 80]) + self.block_size = params.get('LocalShuffling_block_size'.lower(), [4, 8, 8]) def __call__(self, sample): if(random.random() > self.prob): @@ -210,49 +254,33 @@ def __call__(self, sample): img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) img_out = copy.deepcopy(image) - if(self.block_size_min is None): - block_size_min = [2] * img_dim - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim - else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min - - if(self.block_size_max is None): - block_size_max = [img_shape[1+i]//10 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max + block_num = random.randint(self.block_range[0], self.block_range[1]) for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] - coord_min = [random.randint(0, img_shape[1+i] - block_size[i]) \ + coord_min = [random.randint(0, img_shape[1+i] - self.block_size[i]) \ for i in range(img_dim)] if(img_dim == 2): - window = image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] - n_pixels = block_size[0] * block_size[1] + window = image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] + n_pixels = self.block_size[0] * self.block_size[1] else: - window = image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] - n_pixels = block_size[0] * block_size[1] * block_size[2] + window = image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] + n_pixels = self.block_size[0] * self.block_size[1] * self.block_size[2] window = np.reshape(window, [-1, n_pixels]) np.random.shuffle(np.transpose(window)) window = np.transpose(window) if(img_dim == 2): - window = np.reshape(window, [-1, block_size[0], block_size[1]]) - img_out[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] = window + window = np.reshape(window, [-1, self.block_size[0], self.block_size[1]]) + img_out[:, coord_min[0]:coord_min[0] + self.block_size[0], + 
coord_min[1]:coord_min[1] + self.block_size[1]] = window else: - window = np.reshape(window, [-1, block_size[0], block_size[1], block_size[2]]) - img_out[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] = window + window = np.reshape(window, [-1, self.block_size[0], self.block_size[1], self.block_size[2]]) + img_out[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = window sample['image'] = img_out return sample @@ -264,10 +292,9 @@ def __init__(self, params): super(InPainting, self).__init__(params) self.inverse = params.get('InPainting_inverse'.lower(), False) self.prob = params.get('InPainting_probability'.lower(), 0.5) - self.block_range = params.get('InPainting_block_range'.lower(), (1, 6)) - self.block_size_min = params.get('InPainting_block_size_min'.lower(), None) - self.block_size_max = params.get('InPainting_block_size_max'.lower(), None) - + self.block_range = params.get('InPainting_block_range'.lower(), (20, 40)) + self.block_size = params.get('InPainting_block_size'.lower(), [8, 16, 16]) + def __call__(self, sample): if(random.random() > self.prob): return sample @@ -277,38 +304,21 @@ def __call__(self, sample): img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - if(self.block_size_min is None): - block_size_min = [img_shape[1+i]//6 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim - else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min - - if(self.block_size_max is None): - block_size_max = [img_shape[1+i]//3 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max block_num = random.randint(self.block_range[0], self.block_range[1]) - for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] - coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ + for n in range(block_num): + coord_min = [random.randint(3, img_shape[1+i] - self.block_size[i] - 3) \ for i in range(img_dim)] if(img_dim == 2): - random_block = np.random.rand(img_shape[0], block_size[0], block_size[1]) - image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1]] = random_block + random_block = np.random.rand(img_shape[0], self.block_size[0], self.block_size[1]) * 2 -1 + image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1]] = random_block else: - random_block = np.random.rand(img_shape[0], block_size[0], - block_size[1], block_size[2]) - image[:, coord_min[0]:coord_min[0] + block_size[0], - coord_min[1]:coord_min[1] + block_size[1], - coord_min[2]:coord_min[2] + block_size[2]] = random_block + random_block = np.random.rand(img_shape[0], self.block_size[0], + self.block_size[1], self.block_size[2]) * 2 -1 + image[:, coord_min[0]:coord_min[0] + self.block_size[0], + coord_min[1]:coord_min[1] + self.block_size[1], + coord_min[2]:coord_min[2] + self.block_size[2]] = random_block sample['image'] = image return sample @@ -320,9 +330,8 @@ def __init__(self, params): super(OutPainting, self).__init__(params) self.inverse = params.get('OutPainting_inverse'.lower(), False) self.prob = 
params.get('OutPainting_probability'.lower(), 0.5) - self.block_range = params.get('OutPainting_block_range'.lower(), (1, 6)) - self.block_size_min = params.get('OutPainting_block_size_min'.lower(), None) - self.block_size_max = params.get('OutPainting_block_size_max'.lower(), None) + self.block_range = params.get('OutPainting_block_range'.lower(), (2, 8)) + self.block_size = params.get('OutPainting_block_size'.lower(), None) def __call__(self, sample): if(random.random() > self.prob): @@ -332,28 +341,18 @@ def __call__(self, sample): img_shape = image.shape img_dim = len(img_shape) - 1 assert(img_dim == 2 or img_dim == 3) - img_out = np.random.rand(*img_shape) + img_out = np.random.rand(*img_shape) * 2 -1 - if(self.block_size_min is None): - block_size_min = [img_shape[1+i] - 4 * img_shape[1+i]//7 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_min = [self.block_size_min] * img_dim + if(self.block_size is None): + margin = [16, 32, 32] + block_size = [img_shape[1+i] - margin[i] for i in range(img_dim)] else: - assert(len(self.block_size_min) == img_dim) - block_size_min = self.block_size_min + assert(len(self.block_size) == img_dim) + block_size = self.block_size - if(self.block_size_max is None): - block_size_max = [img_shape[1+i] - 3 * img_shape[1+i]//7 for i in range(img_dim)] - elif(isinstance(self.block_size_min, int)): - block_size_max = [self.block_size_max] * img_dim - else: - assert(len(self.block_size_max) == img_dim) - block_size_max = self.block_size_max block_num = random.randint(self.block_range[0], self.block_range[1]) for n in range(block_num): - block_size = [random.randint(block_size_min[i], block_size_max[i]) \ - for i in range(img_dim)] coord_min = [random.randint(3, img_shape[1+i] - block_size[i] - 3) \ for i in range(img_dim)] if(img_dim == 2): @@ -380,8 +379,8 @@ def __init__(self, params): self.inverse = params.get('InOutPainting_inverse'.lower(), False) self.prob = params.get('InOutPainting_probability'.lower(), 0.5) self.in_prob = params.get('InPainting_probability'.lower(), 0.5) - params['InPainting_probability'] = 1.0 - params['outPainting_probability'] = 1.0 + params['InPainting_probability'.lower()] = 1.0 + params['OutPainting_probability'.lower()] = 1.0 self.inpaint = InPainting(params) self.outpaint = OutPainting(params) @@ -392,4 +391,38 @@ def __call__(self, sample): sample = self.inpaint(sample) else: sample = self.outpaint(sample) + return sample + +class PatchSwaping(AbstractTransform): + """ + Apply patch swaping for context restoration in self-supervised learning. + Reference: Liang Chen et al., Self-supervised learning for medical image analysis + using image context restoration, Medical Image Analysis, 2019. 
+ """ + def __init__(self, params): + super(PatchSwaping, self).__init__(params) + self.block_range = params.get('PatchSwaping_block_range'.lower(), (10, 20)) + self.block_size = params.get('PatchSwaping_block_size'.lower(), [8, 16, 16]) + self.inverse = params.get('PatchSwaping_inverse'.lower(), False) + + def __call__(self, sample): + image= sample['image'] + img_shape = image.shape + img_dim = len(img_shape) - 1 + assert(img_dim == 2 or img_dim == 3) + img_out = copy.deepcopy(image) + + block_num = random.randint(self.block_range[0], self.block_range[1]) + for t in range(block_num): + pos_a0 = [random.randint(0, img_shape[-3+i] - self.block_size[i]) for i in range(img_dim)] + pos_b0 = [random.randint(0, img_shape[-3+i] - self.block_size[i]) for i in range(img_dim)] + pos_a1 = [pos_a0[i] + self.block_size[i] for i in range(img_dim)] + pos_b1 = [pos_b0[i] + self.block_size[i] for i in range(img_dim)] + img_out[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] = \ + image[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] + img_out[:, pos_b0[0]:pos_b1[0], pos_b0[1]:pos_b1[1], pos_b0[2]:pos_b1[2]] = \ + image[:, pos_a0[0]:pos_a1[0], pos_a0[1]:pos_a1[1], pos_a0[2]:pos_a1[2]] + + sample['image'] = img_out + sample['label'] = image return sample \ No newline at end of file diff --git a/pymic/transform/mix.py b/pymic/transform/mix.py index fe2f315..6efed6a 100644 --- a/pymic/transform/mix.py +++ b/pymic/transform/mix.py @@ -63,4 +63,91 @@ def __call__(self, sample): coord_min[1]:coord_min[1] + block_size[1], coord_min[2]:coord_min[2] + block_size[2]] = random_block sample['image'] = image - return sample \ No newline at end of file + return sample + +class PatchMix(AbstractTransform): + """ + In-painting of an input image, used for self-supervised learning + """ + def __init__(self, params): + super(PatchMix, self).__init__(params) + self.inverse = params.get('PatchMix_inverse'.lower(), False) + self.threshold = params.get('PatchMix_threshold'.lower(), 0) + self.crop_size = params.get('PatchMix_crop_size'.lower(), [64, 128, 128]) + self.fg_cls_num = params.get('PatchMix_cls_num'.lower(), [4, 40]) + self.patch_num_range= params.get('PatchMix_patch_range'.lower(), [4, 40]) + self.patch_size_min = params.get('PatchMix_patch_size_min'.lower(), [4, 4, 4]) + self.patch_size_max = params.get('PatchMix_patch_size_max'.lower(), [20, 40, 40]) + + def __call__(self, sample): + x0 = self._random_crop_and_flip(sample) + x1 = self._random_crop_and_flip(sample) + C, D, H, W = x0.shape + # generate mask + fg_mask = np.zeros_like(x0, np.uint8) + patch_num = random.randint(self.patch_num_range[0], self.patch_num_range[1]) + for patch in range(patch_num): + d = random.randint(self.patch_size_min[0], self.patch_size_max[0]) + h = random.randint(self.patch_size_min[1], self.patch_size_max[1]) + w = random.randint(self.patch_size_min[2], self.patch_size_max[2]) + d_c = random.randint(0, D) + h_c = random.randint(0, H) + w_c = random.randint(0, W) + d0, d1 = max(0, d_c - d // 2), min(D, d_c + d // 2) + h0, h1 = max(0, h_c - h // 2), min(H, h_c + h // 2) + w0, w1 = max(0, w_c - w // 2), min(W, w_c + w // 2) + temp_m = np.ones([C, d1-d0, h1-h0, w1-w0]) * random.randint(1, self.fg_cls_num) + fg_mask[:, d0:d1, h0:h1, w0:w1] = temp_m + fg_w = fg_mask * 1.0 / self.fg_cls_num + x_fuse = fg_w*x0 + (1.0 - fg_w)*x1 # x1 is used as background + + sample['image'] = x_fuse + sample['label'] = fg_mask + return sample + + def _random_crop_and_flip(self, sample): + image = sample['image'] + input_dim = 
len(image.shape) - 1 + assert(input_dim == 3) + C, D, H, W = image.shape + + half_size = [x // 2 for x in self.crop_size] + dc = random.randint(half_size[0], D - half_size[0]) + image2d = image[0, dc, :, :] + mask2d = np.zeros_like(image2d) + mask2d[half_size[1]:H+1-half_size[1], half_size[2]:W+1-half_size[2]] = \ + np.ones([H-self.crop_size[1]+1, W-self.crop_size[2]+1]) + if('label' in sample): + temp_mask = sample['label'][0, dc, :, :] > 0 + mask2d = temp_mask * mask2d + elif(self.threshold is not None): + temp_mask = image2d > self.threshold + se = np.ones([3,3]) + temp_mask = ndimage.binary_opening(temp_mask, se, iterations = 2) + temp_mask = get_largest_k_components(temp_mask, 1) + mask2d = temp_mask * mask2d + + indices = np.where(mask2d) + n = random.randint(0, len(indices[0])-1) + center = [indices[i][n] for i in range(2)] + crop_min = [dc - half_size[0], center[0]-half_size[1], center[1] - half_size[2]] + crop_max = [crop_min[i] + self.crop_size[i] for i in range(input_dim)] + crop_min = [0] + crop_min + crop_max = [C] + crop_max + x = crop_ND_volume_with_bounding_box(image, crop_min, crop_max) + + flip_axis = [] + if(random.random() > 0.5): + flip_axis.append(-1) + if(random.random() > 0.5): + flip_axis.append(-2) + if(random.random() > 0.5): + flip_axis.append(-3) + if(len(flip_axis) > 0): + x = np.flip(x, flip_axis).copy() + + if(x.shape[1] == 63): + print("crop shape == 63", x.shape) + print(sample['names']) + print(image.shape, crop_min, crop_max) + return x \ No newline at end of file diff --git a/pymic/transform/normalize.py b/pymic/transform/normalize.py index 77852d2..5f0e4ec 100644 --- a/pymic/transform/normalize.py +++ b/pymic/transform/normalize.py @@ -24,35 +24,38 @@ class NormalizeWithMeanStd(AbstractTransform): :param `NormalizeWithMeanStd_std`: (list/tuple or None) The std values along each specified channel. If None, the std values are calculated automatically. - :param `NormalizeWithMeanStd_ignore_non_positive`: (optional, bool) - Only used when mean and std are not given. Default is False. - If True, calculate mean and std in the positive region for normalization, - and set non-positive region to random. If False, calculate - the mean and std values in the entire image region. + :param `NormalizeWithMeanStd_mask_threshold`: (optional, float) + Only used when mean and std are not given. Default is 1.0. + Calculate mean and std in the mask region where the intensity is higher than the mask. + :param `NormalizeWithMeanStd_set_background_to_random`: (optional, bool) + Set background region to random or not, and only applicable when + `NormalizeWithMeanStd_mask_threshold` is not None. Default is True. :param `NormalizeWithMeanStd_inverse`: (optional, bool) Is inverse transform needed for inference. Default is `False`. 
""" def __init__(self, params): super(NormalizeWithMeanStd, self).__init__(params) - self.chns = params['NormalizeWithMeanStd_channels'.lower()] + self.chns = params.get('NormalizeWithMeanStd_channels'.lower(), None) self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None) self.std = params.get('NormalizeWithMeanStd_std'.lower(), None) - self.ingore_np = params.get('NormalizeWithMeanStd_ignore_non_positive'.lower(), False) - self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) + self.mask_thrd = params.get('NormalizeWithMeanStd_mask_threshold'.lower(), 1.0) + self.bg_random = params.get('NormalizeWithMeanStd_set_background_to_random'.lower(), True) + self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False) def __call__(self, sample): image= sample['image'] - chns = self.chns if self.chns is not None else range(image.shape[0]) + if(self.chns is None): + self.chns = range(image.shape[0]) if(self.mean is None): - self.mean = [None] * len(chns) - self.std = [None] * len(chns) + self.mean = [None] * len(self.chns) + self.std = [None] * len(self.chns) - for i in range(len(chns)): - chn = chns[i] + for i in range(len(self.chns)): + chn = self.chns[i] chn_mean, chn_std = self.mean[i], self.std[i] if(chn_mean is None): - if(self.ingore_np): - pixels = image[chn][image[chn] > 0] + if(self.mask_thrd is not None): + pixels = image[chn][image[chn] > self.mask_thrd] if(len(pixels) > 0): chn_mean, chn_std = pixels.mean(), pixels.std() + 1e-5 else: @@ -62,16 +65,16 @@ def __call__(self, sample): chn_norm = (image[chn] - chn_mean)/chn_std - if(self.ingore_np): + if(self.mask_thrd is not None and self.bg_random): chn_random = np.random.normal(0, 1, size = chn_norm.shape) - chn_norm[image[chn] <= 0] = chn_random[image[chn] <= 0] + chn_norm[image[chn] <= self.mask_thrd] = chn_random[image[chn] <=self.mask_thrd] image[chn] = chn_norm sample['image'] = image return sample class NormalizeWithMinMax(AbstractTransform): - """Nomralize the image to [0, 1]. The shape should be [C, D, H, W] or [C, H, W]. + """Nomralize the image to [-1, 1]. The shape should be [C, D, H, W] or [C, H, W]. The arguments should be written in the `params` dictionary, and it has the following fields: @@ -109,13 +112,13 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = (img_chn - v0) / (v1 - v0) + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 image[chn] = img_chn sample['image'] = image return sample class NormalizeWithPercentiles(AbstractTransform): - """Nomralize the image to [0, 1] with percentiles for given channels. + """Nomralize the image to [-1, 1] with percentiles for given channels. The shape should be [C, D, H, W] or [C, H, W]. 
The arguments should be written in the `params` dictionary, and it has the @@ -149,7 +152,7 @@ def __call__(self, sample): img_chn[img_chn < v0] = v0 img_chn[img_chn > v1] = v1 - img_chn = (img_chn - v0) / (v1 - v0) + img_chn = 2.0* (img_chn - v0) / (v1 - v0) -1.0 image[chn] = img_chn sample['image'] = image return sample \ No newline at end of file diff --git a/pymic/transform/pad.py b/pymic/transform/pad.py index c9b75fe..8624aa2 100644 --- a/pymic/transform/pad.py +++ b/pymic/transform/pad.py @@ -6,7 +6,6 @@ import math import random import numpy as np -from scipy import ndimage from pymic import TaskType from pymic.transform.abstract_transform import AbstractTransform from pymic.util.image_process import * diff --git a/pymic/transform/rescale.py b/pymic/transform/rescale.py index 355712e..2896a4e 100644 --- a/pymic/transform/rescale.py +++ b/pymic/transform/rescale.py @@ -100,27 +100,26 @@ def __init__(self, params): self.ratio0 = params["RandomRescale_lower_bound".lower()] self.ratio1 = params["RandomRescale_upper_bound".lower()] self.prob = params.get('RandomRescale_probability'.lower(), 0.5) - self.inverse = params.get("RandomRescale_inverse".lower(), True) + self.inverse = params.get("RandomRescale_inverse".lower(), False) assert isinstance(self.ratio0, (float, list, tuple)) assert isinstance(self.ratio1, (float, list, tuple)) def __call__(self, sample): - # if(random.random() > self.prob): - # print("rescale not started") - # sample['RandomRescale_triggered'] = False - # return sample - # else: - # print("rescale started") - # sample['RandomRescale_triggered'] = True + image = sample['image'] input_shape = image.shape input_dim = len(input_shape) - 1 - + assert(input_dim == len(self.ratio0) and input_dim == len(self.ratio1)) + if isinstance(self.ratio0, (list, tuple)): - for i in range(len(self.ratio0)): + for i in range(input_dim): + if(self.ratio0[i] is None): + self.ratio0[i] = 1.0 + if(self.ratio1[i] is None): + self.ratio1[i] = 1.0 assert(self.ratio0[i] <= self.ratio1[i]) scale = [self.ratio0[i] + random.random()*(self.ratio1[i] - self.ratio0[i]) \ - for i in range(len(self.ratio0))] + for i in range(input_dim)] else: scale = self.ratio0 + random.random()*(self.ratio1 - self.ratio0) scale = [scale] * input_dim @@ -130,12 +129,12 @@ def __call__(self, sample): sample['image'] = image_t sample['RandomRescale_Param'] = json.dumps(input_shape) if('label' in sample and \ - self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): label = sample['label'] label = ndimage.interpolation.zoom(label, scale, order = 0) sample['label'] = label if('pixel_weight' in sample and \ - self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): + self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]): weight = sample['pixel_weight'] weight = ndimage.interpolation.zoom(weight, scale, order = 1) sample['pixel_weight'] = weight @@ -143,8 +142,6 @@ def __call__(self, sample): return sample def inverse_transform_for_prediction(self, sample): - if(not sample['RandomRescale_triggered']): - return sample if(isinstance(sample['RandomRescale_Param'], list) or \ isinstance(sample['RandomRescale_Param'], tuple)): origin_shape = json.loads(sample['RandomRescale_Param'][0]) @@ -157,6 +154,77 @@ def inverse_transform_for_prediction(self, sample): i in range(origin_dim)] scale = [1.0, 1.0] + scale + output_predict = ndimage.interpolation.zoom(predict, scale, order = 1) + sample['predict'] = output_predict + return sample + + 
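A minimal usage sketch for the updated RandomRescale above is given below; the class and parameter names are taken from the hunk above, while the task value, the concrete bounds, and the input shape are illustrative assumptions rather than part of the patch.

import numpy as np
from pymic import TaskType
from pymic.transform.rescale import RandomRescale

# A minimal sketch, assuming the lower-cased keys below match how PyMIC reads
# transform parameters; the values and the input shape are hypothetical.
params = {
    "task": TaskType.SEGMENTATION,
    "randomrescale_lower_bound": [None, 0.8, 0.8],   # None: keep the depth axis unchanged
    "randomrescale_upper_bound": [None, 1.2, 1.2],
    "randomrescale_probability": 0.5,
    "randomrescale_inverse": False,
}
rescale = RandomRescale(params)
sample  = {"image": np.random.rand(1, 32, 96, 96).astype(np.float32)}
out     = rescale(sample)   # in-plane axes are zoomed by a random factor in [0.8, 1.2]

With the new handling above, a None entry in the bounds keeps that axis unscaled, which is convenient for anisotropic 3D volumes where only in-plane rescaling is wanted.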
+class Resample(Rescale):
+    """Resample the image to a given spatial resolution.
+
+    The arguments should be written in the `params` dictionary, and it has the
+    following fields:
+
+    :param `Resample_output_spacing`: (list/tuple) The desired output spacing along each spatial
+        axis, such as [s_d, s_h, s_w]. If an element is None, the input spacing along that axis is kept.
+    :param `Resample_ignore_zspacing_range`: (optional, list/tuple) If given and the input z-axis
+        spacing falls within this range, the z-axis spacing is kept unchanged. Default is None.
+    :param `Resample_inverse`: (optional, bool)
+        Is inverse transform needed for inference. Default is `True`.
+    """
+    def __init__(self, params):
+        super(Rescale, self).__init__(params)
+        self.output_spacing = params["Resample_output_spacing".lower()]
+        self.ignore_zspacing= params.get("Resample_ignore_zspacing_range".lower(), None)
+        self.inverse = params.get("Resample_inverse".lower(), True)
+        # assert isinstance(self.output_size, (int, list, tuple))
+
+    def __call__(self, sample):
+        image = sample['image']
+        input_shape = image.shape
+        input_dim = len(input_shape) - 1
+        spacing = sample['spacing']
+        out_spacing = [item for item in self.output_spacing]
+        for i in range(input_dim):
+            out_spacing[i] = spacing[i] if out_spacing[i] is None else out_spacing[i]
+        if(self.ignore_zspacing is not None):
+            if(spacing[0] > self.ignore_zspacing[0] and spacing[0] < self.ignore_zspacing[1]):
+                out_spacing[0] = spacing[0]
+        scale = [spacing[i] / out_spacing[i] for i in range(input_dim)]
+        scale = [1.0] + scale
+
+        image_t = ndimage.interpolation.zoom(image, scale, order = 1)
+
+        sample['image'] = image_t
+        sample['spacing'] = out_spacing
+        sample['Resample_origin_shape'] = json.dumps(input_shape)
+        if('label' in sample and \
+            self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]):
+            label = sample['label']
+            label = ndimage.interpolation.zoom(label, scale, order = 0)
+            sample['label'] = label
+        if('pixel_weight' in sample and \
+            self.task in [TaskType.SEGMENTATION, TaskType.RECONSTRUCTION]):
+            weight = sample['pixel_weight']
+            weight = ndimage.interpolation.zoom(weight, scale, order = 1)
+            sample['pixel_weight'] = weight
+
+        return sample
+
+    def inverse_transform_for_prediction(self, sample):
+        if(isinstance(sample['Resample_origin_shape'], list) or \
+            isinstance(sample['Resample_origin_shape'], tuple)):
+            origin_shape = json.loads(sample['Resample_origin_shape'][0])
+        else:
+            origin_shape = json.loads(sample['Resample_origin_shape'])
+
+        origin_dim = len(origin_shape) - 1
+        predict = sample['predict']
+        input_shape = predict.shape
+        scale = [(origin_shape[1:][i] + 0.0)/input_shape[2:][i] for \
+            i in range(origin_dim)]
+        scale = [1.0, 1.0] + scale
+        output_predict = ndimage.interpolation.zoom(predict, scale, order = 1)
+        sample['predict'] = output_predict
+        return sample
\ No newline at end of file
diff --git a/pymic/transform/rotate.py b/pymic/transform/rotate.py
index 65e5328..5f85e28 100644
--- a/pymic/transform/rotate.py
+++ b/pymic/transform/rotate.py
@@ -34,8 +34,8 @@ class RandomRotate(AbstractTransform):
     def __init__(self, params):
         super(RandomRotate, self).__init__(params)
         self.angle_range_d = params['RandomRotate_angle_range_d'.lower()]
-        self.angle_range_h = params['RandomRotate_angle_range_h'.lower()]
-        self.angle_range_w = params['RandomRotate_angle_range_w'.lower()]
+        self.angle_range_h = params.get('RandomRotate_angle_range_h'.lower(), None)
+        self.angle_range_w = params.get('RandomRotate_angle_range_w'.lower(), None)
         self.prob = params.get('RandomRotate_probability'.lower(), 0.5)
         self.inverse =
params.get('RandomRotate_inverse'.lower(), True) diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index 5dac73b..ed5ad0c 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -31,6 +31,7 @@ """ from __future__ import print_function, division +from pymic.transform.affine import * from pymic.transform.intensity import * from pymic.transform.flip import * from pymic.transform.pad import * @@ -40,13 +41,16 @@ from pymic.transform.threshold import * from pymic.transform.normalize import * from pymic.transform.crop import * +from pymic.transform.mix import * from pymic.transform.label_convert import * TransformDict = { + 'Affine': Affine, 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, 'CropWithForeground': CropWithForeground, + 'CropHumanRegionFromCT': CropHumanRegionFromCT, 'CenterCrop': CenterCrop, 'GrayscaleToRGB': GrayscaleToRGB, 'GammaCorrection': GammaCorrection, @@ -64,6 +68,7 @@ 'NormalizeWithPercentiles': NormalizeWithPercentiles, 'PartialLabelToProbability':PartialLabelToProbability, 'RandomCrop': RandomCrop, + 'RandomSlice': RandomSlice, 'RandomResizedCrop': RandomResizedCrop, 'RandomRescale': RandomRescale, 'RandomTranspose': RandomTranspose, @@ -71,8 +76,11 @@ 'RandomRotate': RandomRotate, 'ReduceLabelDim': ReduceLabelDim, 'Rescale': Rescale, + 'Resample': Resample, 'SelfReconstructionLabel': SelfReconstructionLabel, 'MaskedImageModelingLabel': MaskedImageModelingLabel, 'OutPainting': OutPainting, 'Pad': Pad, + 'PatchSwaping':PatchSwaping, + 'PatchMix': PatchMix } diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 836401d..a9b114b 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -108,22 +108,30 @@ def binary_hd95(s, g, spacing = None): """ s_edge = get_edge_points(s) g_edge = get_edge_points(g) - image_dim = len(s.shape) - assert(image_dim == len(g.shape)) - if(spacing == None): - spacing = [1.0] * image_dim + ns = s_edge.sum() + ng = g_edge.sum() + if(ns + ng == 0): + hd95 = 0.0 + elif(ns * ng == 0): + hd95 = 100.0 else: - assert(image_dim == len(spacing)) - s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) - g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) - - dist_list1 = s_dis[g_edge > 0] - dist_list1 = sorted(dist_list1) - dist1 = dist_list1[int(len(dist_list1)*0.95)] - dist_list2 = g_dis[s_edge > 0] - dist_list2 = sorted(dist_list2) - dist2 = dist_list2[int(len(dist_list2)*0.95)] - return max(dist1, dist2) + image_dim = len(s.shape) + assert(image_dim == len(g.shape)) + if(spacing == None): + spacing = [1.0] * image_dim + else: + assert(image_dim == len(spacing)) + s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) + g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) + + dist_list1 = s_dis[g_edge > 0] + dist_list1 = sorted(dist_list1) + dist1 = dist_list1[int(len(dist_list1)*0.95)] + dist_list2 = g_dis[s_edge > 0] + dist_list2 = sorted(dist_list2) + dist2 = dist_list2[int(len(dist_list2)*0.95)] + hd95 = max(dist1, dist2) + return hd95 def binary_assd(s, g, spacing = None): @@ -150,9 +158,14 @@ def binary_assd(s, g, spacing = None): ns = s_edge.sum() ng = g_edge.sum() - s_dis_g_edge = s_dis * g_edge - g_dis_s_edge = g_dis * s_edge - assd = (s_dis_g_edge.sum() + g_dis_s_edge.sum()) / (ns + ng) + if(ns + ng == 0): + assd = 0.0 + elif(ns*ng == 0): + assd = 20.0 + else: + s_dis_g_edge 
= s_dis * g_edge + g_dis_s_edge = g_dis * s_edge + assd = (s_dis_g_edge.sum() + g_dis_s_edge.sum()) / (ns + ng) return assd # relative volume error evaluation @@ -315,8 +328,10 @@ def evaluation(config): # save the result as csv if(output_name is None): - output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) - with open(output_name, mode='w') as csv_file: + metric_output_name = "{0:}/eval_{1:}.csv".format(seg_root, metric) + else: + metric_output_name = output_name + with open(metric_output_name, mode='w') as csv_file: csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"',quoting=csv.QUOTE_MINIMAL) head = ['image'] + ["class_{0:}".format(i) for i in label_list] diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 8ae8e80..c813e5d 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -167,18 +167,46 @@ def random_crop_ND_volume(volume, out_shape): crop_volume = crop_ND_volume_with_bounding_box(image_pad, bb_min, bb_max) return crop_volume -def get_random_box_from_mask(mask, out_shape): - indexes = np.where(mask) - voxel_num = len(indexes[0]) - dim = len(out_shape) - left_bound = [int(out_shape[i]/2) for i in range(dim)] - right_bound = [mask.shape[i] - (out_shape[i] - left_bound[i]) for i in range(dim)] +def get_random_box_from_mask(mask, out_shape, mode = 0): + """ + get a bounding box of a subvolume according to a mask + + mode == 0: The output bounding box should be a sub region of the mask region + mode == 1: The center point of the output bounding box can be ahy where of the mask region + """ + dim = len(out_shape) + left_margin = [int(out_shape[i]/2) for i in range(dim)] + right_margin = [out_shape[i] - left_margin[i] for i in range(dim)] + + if(mode == 0): + bb_mask_min, bb_mask_max = get_ND_bounding_box(mask) + bb_valid_min, bb_valid_max = [], [] + for i in range(dim): + mask_size = bb_mask_max[i] - bb_mask_min[i] + if(mask_size > out_shape[i]): + valid_left = bb_mask_min[i] + left_margin[i] + valid_right = bb_mask_max[i] - right_margin[i] + else: + valid_left = (bb_mask_max[i] - bb_mask_min[i]) // 2 + valid_right = valid_left + 1 + bb_valid_min.append(valid_left) + bb_valid_max.append(valid_right) + + valid_region_shape = [bb_valid_max[i] - bb_valid_min[i] for i in range(dim)] + valid_mask = np.zeros_like(mask) + valid_mask = set_ND_volume_roi_with_bounding_box_range(valid_mask, + bb_valid_min, bb_valid_max, np.ones(valid_region_shape, np.bool), addition = True) + valid_mask = valid_mask * mask + else: + valid_mask = mask + indices = np.where(valid_mask) + voxel_num = len(indices[0]) j = random.randint(0, voxel_num - 1) - bb_c = [int(indexes[i][j]) for i in range(dim)] - bb_c = [max(left_bound[i], bb_c[i]) for i in range(dim)] - bb_c = [min(right_bound[i], bb_c[i]) for i in range(dim)] - bb_min = [bb_c[i] - left_bound[i] for i in range(dim)] + bb_c = [int(indices[i][j]) for i in range(dim)] + bb_min = [max(0, bb_c[i] - left_margin[i]) for i in range(dim)] + mask_shape = np.shape(mask) + bb_min = [min(bb_min[i], mask_shape[i] - out_shape[i]) for i in range(dim)] bb_max = [bb_min[i] + out_shape[i] for i in range(dim)] return bb_min, bb_max @@ -205,7 +233,7 @@ def random_crop_ND_volume_with_mask(volume, out_shape, mask): pad = [(ml[i], mr[i]) for i in range(dim)] pad = tuple(pad) image_pad = np.pad(volume, pad, 'reflect') - mask_pad = np.pad(mask, pad, 'reflect') + mask_pad = np.pad(mask, pad, 'constant') bb_min, bb_max = get_random_box_from_mask(mask_pad, out_shape) # left_margin = [int(out_shape[i]/2) for i in 
range(dim)] @@ -299,7 +327,7 @@ def convert_label(label, source_list, target_list): label_converted[label_s > 0] = label_t[label_s > 0] return label_converted -def resample_sitk_image_to_given_spacing(image, spacing, order): +def resample_sitk_image_to_given_spacing(image, spacing, order = 3): """ Resample an sitk image objct to a given spacing. @@ -319,26 +347,44 @@ def resample_sitk_image_to_given_spacing(image, spacing, order): out_img.SetDirection(image.GetDirection()) return out_img -def get_image_info(img_names): - space0, space1, slices = [], [], [] +def get_image_info(img_names, output_csv = None): + spacing_list, shape_list = [], [] for img_name in img_names: img_obj = sitk.ReadImage(img_name) img_arr = sitk.GetArrayFromImage(img_obj) spacing = img_obj.GetSpacing() - slices.append(img_arr.shape[0]) - space0.append(spacing[0]) - space1.append(spacing[2]) - print(img_name, spacing, img_arr.shape) - - space0 = np.asarray(space0) - space1 = np.asarray(space1) - slices = np.asarray(slices) - print("intra-slice spacing") - print(space0.min(), space0.max(), space0.mean()) - print("inter-slice spacing") - print(space1.min(), space1.max(), space1.mean()) - print("slice number") - print(slices.min(), slices.max(), slices.mean()) + shape = img_arr.shape + spacing_list.append(spacing) + shape_list.append(shape) + print(img_name, spacing, shape) + spacings = np.asarray(spacing_list) + shapes = np.asarray(shape_list) + spacing_min = spacings.min(axis = 0) + spacing_max = spacings.max(axis = 0) + spacing_median = np.percentile(spacings, 50, axis = 0) + print("spacing min", spacing_min) + print("spacing max", spacing_max) + print("spacing median", spacing_median) + + shape_min = shapes.min(axis = 0) + shape_max = shapes.max(axis = 0) + shape_median = np.percentile(shapes, 50, axis = 0) + print("shape min", shape_min) + print("shape max", shape_max) + print("shape median", shape_median) + + if(output_csv is not None): + img_names_short = [item.split("/")[-1] for item in img_names] + img_names_short.extend(["spacing min", "spacing max", "spacing median", + "shape min", "shape max", "shape median"]) + spacing_list.extend([spacing_min, spacing_max, spacing_median, + shape_min, shape_max, shape_median]) + shape_list.extend(['']* 6) + out_dict = {"img_name": img_names_short, + "spacing": spacing_list, + "shape": shape_list} + df = pd.DataFrame.from_dict(out_dict) + df.to_csv(output_csv, index=False) def get_average_mean_std(data_dir, data_csv): df = pd.read_csv(data_csv)
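For the refactored get_image_info above, a minimal call sketch follows; the data directory and file pattern are assumptions, while the function name and the optional output_csv argument come from the hunk above.

from glob import glob
from pymic.util.image_process import get_image_info

# A minimal sketch; the data path and naming pattern are hypothetical.
img_names = sorted(glob("./data/images/*.nii.gz"))
# Prints per-image spacing and shape, summarizes their min/max/median values,
# and writes everything to a CSV file when output_csv is given.
get_image_info(img_names, output_csv = "./data/image_info.csv")

When output_csv is omitted, the per-image spacings and shapes and their summary statistics are only printed to the console.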