Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
taigw committed Jan 3, 2024
2 parents f43f409 + a906dfc commit d71452e
Show file tree
Hide file tree
Showing 62 changed files with 4,995 additions and 772 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ dist/*
*egg*/*
*stop*
files.txt
pymic/test/runs/*

# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks
# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks
Expand Down
3 changes: 0 additions & 3 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,5 @@ API
pymic.loss
pymic.net
pymic.net_run
pymic.net_run_nll
pymic.net_run_ssl
pymic.net_run_wsl
pymic.transform
pymic.util
8 changes: 8 additions & 0 deletions docs/source/pymic.net.net2d.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ pymic.net.net2d.unet2d\_dual\_branch module
:undoc-members:
:show-inheritance:

pymic.net.net2d.unet2d\_mcnet module
-------------------------------------------

.. automodule:: pymic.net.net2d.unet2d_mcnet
:members:
:undoc-members:
:show-inheritance:

pymic.net.net2d.unet2d\_nest module
-----------------------------------

Expand Down
8 changes: 8 additions & 0 deletions docs/source/pymic.net_run.semi_sup.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ pymic.net\_run.semi\_sup.ssl\_cps module
:undoc-members:
:show-inheritance:

pymic.net\_run.semi\_sup.ssl\_mcnet module
----------------------------------------

.. automodule:: pymic.net_run.semi_sup.ssl_mcnet
:members:
:undoc-members:
:show-inheritance:

pymic.net\_run.semi\_sup.ssl\_em module
---------------------------------------

Expand Down
7 changes: 5 additions & 2 deletions pymic/io/h5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from pymic import TaskType

class H5DataSet(Dataset):
class H5DataSet_backup(Dataset):
"""
Dataset for loading images stored in h5 format. It generates
4D tensors with dimention order [C, D, H, W] for 3D images, and
Expand Down Expand Up @@ -39,7 +40,9 @@ def __getitem__(self, idx):
if self.transform:
sample = self.transform(sample)
return sample




class TwoStreamBatchSampler(Sampler):
"""Iterate two sets of indices
Expand Down
24 changes: 15 additions & 9 deletions pymic/io/image_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,9 @@ def load_nifty_volume_as_4d_array(filename):
spacing = img_obj.GetSpacing()
direction = img_obj.GetDirection()
shape = data_array.shape
if(len(shape) == 4):
assert(shape[3] == 1)
elif(len(shape) == 3):
if(len(shape) == 3):
data_array = np.expand_dims(data_array, axis = 0)
else:
elif(len(shape) > 4 or len(shape) < 3):
raise ValueError("unsupported image dim: {0:}".format(len(shape)))
output = {}
output['data_array'] = data_array
Expand Down Expand Up @@ -81,25 +79,32 @@ def load_image_as_nd_array(image_name):
image_name.endswith(".tif") or image_name.endswith(".png")):
image_dict = load_rgb_image_as_3d_array(image_name)
else:
raise ValueError("unsupported image format")
raise ValueError("unsupported image format: {0:}".format(image_name))
return image_dict

def save_array_as_nifty_volume(data, image_name, reference_name = None):
def save_array_as_nifty_volume(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]):
"""
Save a numpy array as nifty image
:param data: (numpy.ndarray) A numpy array with shape [Depth, Height, Width].
:param image_name: (str) The ouput file name.
:param reference_name: (str) File name of the reference image of which
meta information is used.
:param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided.
"""
img = sitk.GetImageFromArray(data)
if(reference_name is not None):
img_ref = sitk.ReadImage(reference_name)
#img.CopyInformation(img_ref)
img.SetSpacing(img_ref.GetSpacing())
img.SetOrigin(img_ref.GetOrigin())
img.SetDirection(img_ref.GetDirection())
direction0 = img_ref.GetDirection()
direction1 = img.GetDirection()
if(len(direction0) == len(direction1)):
img.SetDirection(direction0)
else:
nifty_spacing = spacing[1:] + spacing[:1]
img.SetSpacing(nifty_spacing)
sitk.WriteImage(img, image_name)

def save_array_as_rgb_image(data, image_name):
Expand All @@ -118,21 +123,22 @@ def save_array_as_rgb_image(data, image_name):
img = Image.fromarray(data)
img.save(image_name)

def save_nd_array_as_image(data, image_name, reference_name = None):
def save_nd_array_as_image(data, image_name, reference_name = None, spacing = [1.0,1.0,1.0]):
"""
Save a 3D or 2D numpy array as medical image or RGB image
:param data: (numpy.ndarray) A numpy array with shape [3, H, W] or
[H, W, 3] or [H, W].
:param reference_name: (str) File name of the reference image of which
meta information is used.
:param spacing: (list or tuple) the spacing of a volume data when `reference_name` is not provided.
"""
data_dim = len(data.shape)
assert(data_dim == 2 or data_dim == 3)
if (image_name.endswith(".nii.gz") or image_name.endswith(".nii") or
image_name.endswith(".mha")):
assert(data_dim == 3)
save_array_as_nifty_volume(data, image_name, reference_name)
save_array_as_nifty_volume(data, image_name, reference_name, spacing)

elif(image_name.endswith(".jpg") or image_name.endswith(".jpeg") or
image_name.endswith(".tif") or image_name.endswith(".png")):
Expand Down
42 changes: 31 additions & 11 deletions pymic/io/nifty_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@

import logging
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.data import Dataset
from pymic import TaskType
from pymic.io.image_read_write import load_image_as_nd_array

Expand Down Expand Up @@ -38,7 +36,8 @@ def __init__(self, root_dir, csv_file, modal_num = 1,
if('label' not in csv_keys):
logging.warning("`label` section is not found in the csv file {0:}".format(
csv_file) + "\n -- This is only allowed for self-supervised learning" +
"\n -- when `SelfSuperviseLabel` is used in the transform.")
"\n -- when `SelfSuperviseLabel` is used in the transform, or when" +
"\n -- loading the unlabeled data for preprocessing.")
self.with_label = False
self.image_weight_idx = None
self.pixel_weight_idx = None
Expand All @@ -52,15 +51,15 @@ def __len__(self):

def __getlabel__(self, idx):
csv_keys = list(self.csv_items.keys())
label_idx = csv_keys.index('label')
label_name = "{0:}/{1:}".format(self.root_dir,
self.csv_items.iloc[idx, label_idx])
label = load_image_as_nd_array(label_name)['data_array']
label_idx = csv_keys.index('label')
label_name = self.csv_items.iloc[idx, label_idx]
label_name_full = "{0:}/{1:}".format(self.root_dir, label_name)
label = load_image_as_nd_array(label_name_full)['data_array']
if(self.task == TaskType.SEGMENTATION):
label = np.asarray(label, np.int32)
elif(self.task == TaskType.RECONSTRUCTION):
label = np.asarray(label, np.float32)
return label
return label, label_name

def __get_pixel_weight__(self, idx):
weight_name = "{0:}/{1:}".format(self.root_dir,
Expand All @@ -69,6 +68,25 @@ def __get_pixel_weight__(self, idx):
weight = np.asarray(weight, np.float32)
return weight

# def __getitem__(self, idx):
# sample_name = self.csv_items.iloc[idx, 0]
# h5f = h5py.File(self.root_dir + '/' + sample_name, 'r')
# image = np.asarray(h5f['image'][:], np.float32)

# # this a temporaory process, will be delieted later
# if(len(image.shape) == 3 and image.shape[0] > 1):
# image = np.expand_dims(image, 0)
# sample = {'image': image, 'names':sample_name}

# if('label' in h5f):
# label = np.asarray(h5f['label'][:], np.uint8)
# if(len(label.shape) == 3 and label.shape[0] > 1):
# label = np.expand_dims(label, 0)
# sample['label'] = label
# if self.transform:
# sample = self.transform(sample)
# return sample

def __getitem__(self, idx):
names_list, image_list = [], []
for i in range (self.modal_num):
Expand All @@ -80,12 +98,14 @@ def __getitem__(self, idx):
image_list.append(image_data)
image = np.concatenate(image_list, axis = 0)
image = np.asarray(image, np.float32)
sample = {'image': image, 'names' : names_list[0],

sample = {'image': image, 'names' : names_list,
'origin':image_dict['origin'],
'spacing': image_dict['spacing'],
'direction':image_dict['direction']}
if (self.with_label):
sample['label'] = self.__getlabel__(idx)
sample['label'], label_name = self.__getlabel__(idx)
sample['names'].append(label_name)
assert(image.shape[1:] == sample['label'].shape[1:])
if (self.image_weight_idx is not None):
sample['image_weight'] = self.csv_items.iloc[idx, self.image_weight_idx]
Expand Down
15 changes: 13 additions & 2 deletions pymic/loss/seg/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,20 @@ class AbstractSegLoss(nn.Module):
def __init__(self, params = None):
super(AbstractSegLoss, self).__init__()
if(params is None):
self.softmax = True
self.acti_func = 'softmax'
else:
self.softmax = params.get('loss_softmax', True)
self.acti_func = params.get('loss_acti_func', 'softmax')

def get_activated_prediction(self, p, acti_func = 'softmax'):
if(acti_func == "softmax"):
p = nn.Softmax(dim = 1)(p)
elif(acti_func == "tanh"):
p = nn.Tanh()(p)
elif(acti_func == "sigmoid"):
p = nn.Sigmoid()(p)
else:
raise ValueError("activation for output is not supported: {0:}".format(acti_func))
return p

def forward(self, loss_input_dict):
"""
Expand Down
15 changes: 9 additions & 6 deletions pymic/loss/seg/ce.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ class CrossEntropyLoss(AbstractSegLoss):
The parameters should be written in the `params` dictionary, and it has the
following fields:
:param `loss_softmax`: (optional, bool)
Apply softmax to the prediction of network or not. Default is True.
:param `loss_acti_func`: (optional, string)
Apply an activation function to the prediction of network or not, for example,
'softmax' for image segmentation tasks, 'tanh' for reconstruction tasks, and None
means no activation is used.
"""
def __init__(self, params = None):
super(CrossEntropyLoss, self).__init__(params)
Expand All @@ -27,8 +29,9 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)

predict = reshape_tensor_to_2D(predict)
soft_y = reshape_tensor_to_2D(soft_y)

Expand Down Expand Up @@ -74,8 +77,8 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)
predict = reshape_tensor_to_2D(predict)
soft_y = reshape_tensor_to_2D(soft_y)
gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y
Expand Down
20 changes: 10 additions & 10 deletions pymic/loss/seg/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)
predict = reshape_tensor_to_2D(predict)
soft_y = reshape_tensor_to_2D(soft_y)
if(pix_w is not None):
Expand All @@ -52,8 +52,8 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)
predict = 1.0 - predict[:, :1, :, :, :]
soft_y = 1.0 - soft_y[:, :1, :, :, :]
predict = reshape_tensor_to_2D(predict)
Expand All @@ -76,8 +76,8 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)
predict = reshape_tensor_to_2D(predict)
soft_y = reshape_tensor_to_2D(soft_y)
num_class = list(predict.size())[1]
Expand Down Expand Up @@ -115,8 +115,8 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)
predict = reshape_tensor_to_2D(predict)
soft_y = reshape_tensor_to_2D(soft_y)

Expand Down Expand Up @@ -149,8 +149,8 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)
predict = reshape_tensor_to_2D(predict)
soft_y = reshape_tensor_to_2D(soft_y)

Expand Down
4 changes: 2 additions & 2 deletions pymic/loss/seg/exp_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)
predict = reshape_tensor_to_2D(predict)
soft_y = reshape_tensor_to_2D(soft_y)

Expand Down
8 changes: 4 additions & 4 deletions pymic/loss/seg/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)
mse = torch.square(predict - soft_y)
mse = torch.mean(mse)
return mse
Expand All @@ -44,8 +44,8 @@ def forward(self, loss_input_dict):

if(isinstance(predict, (list, tuple))):
predict = predict[0]
if(self.softmax):
predict = nn.Softmax(dim = 1)(predict)
if(self.acti_func is not None):
predict = self.get_activated_prediction(predict, self.acti_func)
mae = torch.abs(predict - soft_y)
if(weight is None):
mae = torch.mean(mae)
Expand Down
Loading

0 comments on commit d71452e

Please sign in to comment.