From 9b2b072191a2abf1b8dc0f0db6828ff47d2dc1c6 Mon Sep 17 00:00:00 2001 From: Ebrahim Ebrahim Date: Wed, 15 May 2024 11:29:19 -0400 Subject: [PATCH] make automated style fixes using ruff --- LICENSE | 2 +- README.md | 173 +++++++------ src/hd_bet/config.py | 51 ++-- src/hd_bet/data_loading.py | 98 ++++--- src/hd_bet/hd_bet_cli.py | 186 ++++++++++---- src/hd_bet/network_architecture.py | 397 ++++++++++++++++++++++------- src/hd_bet/paths.py | 4 +- src/hd_bet/predict_case.py | 66 +++-- src/hd_bet/run.py | 76 ++++-- src/hd_bet/utils.py | 59 +++-- 10 files changed, 789 insertions(+), 323 deletions(-) diff --git a/LICENSE b/LICENSE index 9c8f3ea..8dada3e 100644 --- a/LICENSE +++ b/LICENSE @@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. diff --git a/README.md b/README.md index d3313c5..1969737 100644 --- a/README.md +++ b/README.md @@ -26,105 +26,128 @@ -This repository provides easy to use access to our recently published HD-BET brain extraction tool. HD-BET is the result -of a joint project between the Department of Neuroradiology at the Heidelberg University Hospital and the +This repository provides easy to use access to our recently published HD-BET +brain extraction tool. HD-BET is the result of a joint project between the +Department of Neuroradiology at the Heidelberg University Hospital and the Division of Medical Image Computing at the German Cancer Research Center (DKFZ). -If you are using HD-BET, please cite the following publication: +If you are using HD-BET, please cite the following publication: -Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W, -Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial neural +Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, +Schlemmer HP, Heiland S, Wick W, Bendszus M, Maier-Hein KH, Kickingereder P. +Automated brain extraction of multi-sequence MRI using artificial neural networks. Hum Brain Mapp. 2019; 1–13. https://doi.org/10.1002/hbm.24750 -Compared to other commonly used brain extraction tools, HD-BET has some significant advantages: -- HD-BET was developed with MRI-data from a large multicentric clinical trial in adult brain tumor patients acquired -across 37 institutions in Europe and included a broad range of MR hardware and acquisition parameters, pathologies -or treatment-induced tissue alterations. We used 2/3 of data for training and validation and 1/3 for testing. -Moreover independent testing of HD-BET was performed in three public benchmark datasets (NFBS, LPBA40 and CC-359). -- HD-BET was trained with precontrast T1-w, postcontrast T1-w, T2-w and FLAIR sequences. It can perform independent -brain extraction on various different MRI sequences and is not restricted to precontrast T1-weighted (T1-w) sequences. - Other MRI sequences may work as well (just give it a try!) -- HD-BET was designed to be robust with respect to brain tumors, lesions and resection cavities as well as different -MRI scanner hardware and acquisition parameters. -- HD-BET outperformed five publicly available brain extraction algorithms (FSL BET, AFNI 3DSkullStrip, Brainsuite BSE, -ROBEX and BEaST) across all datasets and yielded median improvements of +1.33 to +2.63 points for the DICE -coefficient and -0.80 to -2.75 mm for the Hausdorff distance (Bonferroni-adjusted p<0.001). -- HD-BET is very fast on GPU with <10s run time per MRI sequence. Even on CPU it is not slower than other commonly -used tools. - -## Installation Instructions -Note that you need to have a python3 installation for HD-BET to work. Please also make sure to install HD-BET with the -correct pip version (the one that is connected to python3). You can verify this using the `--version` command: +Compared to other commonly used brain extraction tools, HD-BET has some +significant advantages: + +- HD-BET was developed with MRI-data from a large multicentric clinical trial in + adult brain tumor patients acquired across 37 institutions in Europe and + included a broad range of MR hardware and acquisition parameters, pathologies + or treatment-induced tissue alterations. We used 2/3 of data for training and + validation and 1/3 for testing. Moreover independent testing of HD-BET was + performed in three public benchmark datasets (NFBS, LPBA40 and CC-359). +- HD-BET was trained with precontrast T1-w, postcontrast T1-w, T2-w and FLAIR + sequences. It can perform independent brain extraction on various different + MRI sequences and is not restricted to precontrast T1-weighted (T1-w) + sequences. Other MRI sequences may work as well (just give it a try!) +- HD-BET was designed to be robust with respect to brain tumors, lesions and + resection cavities as well as different MRI scanner hardware and acquisition + parameters. +- HD-BET outperformed five publicly available brain extraction algorithms (FSL + BET, AFNI 3DSkullStrip, Brainsuite BSE, ROBEX and BEaST) across all datasets + and yielded median improvements of +1.33 to +2.63 points for the DICE + coefficient and -0.80 to -2.75 mm for the Hausdorff distance + (Bonferroni-adjusted p<0.001). +- HD-BET is very fast on GPU with <10s run time per MRI sequence. Even on CPU it + is not slower than other commonly used tools. + +## Installation Instructions + +Note that you need to have a python3 installation for HD-BET to work. Please +also make sure to install HD-BET with the correct pip version (the one that is +connected to python3). You can verify this using the `--version` command: ``` (dl_venv) fabian@Fabian:~$ pip --version pip 20.0.2 from /home/fabian/dl_venv/lib/python3.6/site-packages/pip (python 3.6) ``` -If it does not show python 3.X, you can try pip3. If that also does not work you probably need to install python3 first. - -Once python 3 and pip are set up correctly, run the following commands to install HD-BET: -1) Clone this repository: - ```bash - git clone https://github.com/MIC-DKFZ/HD-BET - ``` -2) Go into the repository (the folder with the setup.py file) and install: - ``` - cd HD-BET - pip install -e . - ``` -3) Per default, model parameters will be downloaded to ~/hd-bet_params. If you wish to use a different folder, open -HD_BET/paths.py in a text editor and modify ```folder_with_parameter_files``` - - -## How to use it - -Using HD_BET is straightforward. You can use it in any terminal on your linux system. The hd-bet command was installed -automatically. We provide CPU as well as GPU support. Running on GPU is a lot faster though -and should always be preferred. Here is a minimalistic example of how you can use HD-BET (you need to be in the HD_BET -directory) +If it does not show python 3.X, you can try pip3. If that also does not work you +probably need to install python3 first. + +Once python 3 and pip are set up correctly, run the following commands to +install HD-BET: + +1. Clone this repository: + ```bash + git clone https://github.com/MIC-DKFZ/HD-BET + ``` +2. Go into the repository (the folder with the setup.py file) and install: + ``` + cd HD-BET + pip install -e . + ``` +3. Per default, model parameters will be downloaded to ~/hd-bet_params. If you + wish to use a different folder, open HD_BET/paths.py in a text editor and + modify `folder_with_parameter_files` + +## How to use it + +Using HD_BET is straightforward. You can use it in any terminal on your linux +system. The hd-bet command was installed automatically. We provide CPU as well +as GPU support. Running on GPU is a lot faster though and should always be +preferred. Here is a minimalistic example of how you can use HD-BET (you need to +be in the HD_BET directory) ```bash hd-bet -i INPUT_FILENAME ``` -INPUT_FILENAME must be a nifti (.nii.gz) file containing 3D MRI image data. 4D image sequences are not supported -(however can be splitted upfront into the individual temporal volumes using fslsplit1). -INPUT_FILENAME can be either a pre- or postcontrast T1-w, T2-w or FLAIR MRI sequence. Other modalities might work as well. -Input images must match the orientation of standard MNI152 template! Use fslreorient2std 2 upfront to ensure -that this is the case. +INPUT_FILENAME must be a nifti (.nii.gz) file containing 3D MRI image data. 4D +image sequences are not supported (however can be splitted upfront into the +individual temporal volumes using fslsplit1). INPUT_FILENAME can be +either a pre- or postcontrast T1-w, T2-w or FLAIR MRI sequence. Other modalities +might work as well. Input images must match the orientation of standard MNI152 +template! Use fslreorient2std 2 upfront to ensure that this is the +case. -By default, HD-BET will run in GPU mode, use the parameters of all five models (which originate from a five-fold -cross-validation), use test time data augmentation by mirroring along all axes and not do any postprocessing. +By default, HD-BET will run in GPU mode, use the parameters of all five models +(which originate from a five-fold cross-validation), use test time data +augmentation by mirroring along all axes and not do any postprocessing. -For batch processing it is faster to process an entire folder at once as this will mitigate the overhead of loading -and initializing the model for each case: +For batch processing it is faster to process an entire folder at once as this +will mitigate the overhead of loading and initializing the model for each case: ```bash hd-bet -i INPUT_FOLDER -o OUTPUT_FOLDER ``` -The above command will look for all nifti files (*.nii.gz) in the INPUT_FOLDER and save the brain masks under the same name -in OUTPUT_FOLDER. +The above command will look for all nifti files (\*.nii.gz) in the INPUT_FOLDER +and save the brain masks under the same name in OUTPUT_FOLDER. -### GPU is nice, but I don't have one of those... What now? +### GPU is nice, but I don't have one of those... What now? -HD-BET has CPU support. Running on CPU takes a lot longer though and you will need quite a bit of RAM. To run on CPU, -we recommend you use the following command: +HD-BET has CPU support. Running on CPU takes a lot longer though and you will +need quite a bit of RAM. To run on CPU, we recommend you use the following +command: ```bash hd-bet -i INPUT_FOLDER -o OUTPUT_FOLDER -device cpu -mode fast -tta 0 ``` + This works of course also with just an input file: ```bash hd-bet -i INPUT_FILENAME -device cpu -mode fast -tta 0 ``` -The options *-mode fast* and *-tta 0* will disable test time data augmentation (speedup of 8x) and use only one model instead of an ensemble of five models -for the prediction. +The options _-mode fast_ and _-tta 0_ will disable test time data augmentation +(speedup of 8x) and use only one model instead of an ensemble of five models for +the prediction. ### More options: + For more information, please refer to the help functionality: ```bash @@ -133,18 +156,20 @@ hd-bet --help ## FAQ -1) **How much GPU memory do I need to run HD-BET?** -We ran all our experiments on NVIDIA Titan X GPUs with 12 GB memory. For inference you will need less, but since -inference in implemented by exploiting the fully convolutional nature of CNNs the amount of memory required depends on -your image. Typical image should run with less than 4 GB of GPU memory consumption. If you run into out of memory -problems please check the following: 1) Make sure the voxel spacing of your data is correct and 2) Ensure your MRI -image only contains the head region -2) **Will you provide the training code as well?** -No. The training code is tightly wound around the data which we cannot make public. -3) **What run time can I expect on CPU/GPU?** -This depends on your MRI image size. Typical run times (preprocessing, postprocessing and resampling included) are just - a couple of seconds for GPU and about 2 Minutes on CPU (using ```-tta 0 -mode fast```) - +1. **How much GPU memory do I need to run HD-BET?** We ran all our experiments + on NVIDIA Titan X GPUs with 12 GB memory. For inference you will need less, + but since inference in implemented by exploiting the fully convolutional + nature of CNNs the amount of memory required depends on your image. Typical + image should run with less than 4 GB of GPU memory consumption. If you run + into out of memory problems please check the following: 1) Make sure the + voxel spacing of your data is correct and 2) Ensure your MRI image only + contains the head region +2. **Will you provide the training code as well?** No. The training code is + tightly wound around the data which we cannot make public. +3. **What run time can I expect on CPU/GPU?** This depends on your MRI image + size. Typical run times (preprocessing, postprocessing and resampling + included) are just a couple of seconds for GPU and about 2 Minutes on CPU + (using `-tta 0 -mode fast`) 1https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Fslutils diff --git a/src/hd_bet/config.py b/src/hd_bet/config.py index 870951e..ab4d41f 100755 --- a/src/hd_bet/config.py +++ b/src/hd_bet/config.py @@ -1,11 +1,15 @@ +from __future__ import annotations + +from abc import abstractmethod + import numpy as np import torch -from HD_BET.utils import SetNetworkToVal, softmax_helper -from abc import abstractmethod -from HD_BET.network_architecture import Network + +from hd_bet.network_architecture import Network +from hd_bet.utils import SetNetworkToVal, softmax_helper -class BaseConfig(object): +class BaseConfig: def __init__(self): pass @@ -31,8 +35,8 @@ def preprocess(self, data): def __repr__(self): res = "" for v in vars(self): - if not v.startswith("__") and not v.startswith("_") and v != 'dataset': - res += (v + ": " + str(self.__getattribute__(v)) + "\n") + if not v.startswith("__") and not v.startswith("_") and v != "dataset": + res += v + ": " + str(self.__getattribute__(v)) + "\n" return res @@ -40,7 +44,7 @@ class HD_BET_Config(BaseConfig): def __init__(self): super(HD_BET_Config, self).__init__() - self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name + self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name # network parameters self.net_base_num_layers = 21 @@ -62,13 +66,15 @@ def __init__(self): # validation self.val_use_DO = False - self.val_use_train_mode = False # for dropout sampling - self.val_num_repeats = 1 # only useful if dropout sampling - self.val_batch_size = 1 # only useful if dropout sampling + self.val_use_train_mode = False # for dropout sampling + self.val_num_repeats = 1 # only useful if dropout sampling + self.val_batch_size = 1 # only useful if dropout sampling self.val_save_npz = True - self.val_do_mirroring = True # test time data augmentation via mirroring + self.val_do_mirroring = True # test time data augmentation via mirroring self.val_write_images = True - self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property + self.net_input_must_be_divisible_by = ( + 16 # we could make a network class that has this as a property + ) self.val_min_size = self.INPUT_PATCH_SIZE self.val_fn = None @@ -78,13 +84,25 @@ def __init__(self): self.val_use_moving_averages = False def get_network(self, train=True, pretrained_weights=None): - net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers, - self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias, - self.net_norm_use_affine, True, self.net_do_DS) + net = Network( + self.num_classes, + len(self.selected_data_channels), + self.net_base_num_layers, + self.net_dropout_p, + softmax_helper, + self.net_leaky_relu_slope, + self.net_conv_use_bias, + self.net_norm_use_affine, + True, + self.net_do_DS, + ) if pretrained_weights is not None: net.load_state_dict( - torch.load(pretrained_weights, map_location=lambda storage, loc: storage)) + torch.load( + pretrained_weights, map_location=lambda storage, loc: storage + ) + ) if train: net.train(True) @@ -118,4 +136,3 @@ def preprocess(self, data): config = HD_BET_Config - diff --git a/src/hd_bet/data_loading.py b/src/hd_bet/data_loading.py index 0a953ba..a617b3c 100755 --- a/src/hd_bet/data_loading.py +++ b/src/hd_bet/data_loading.py @@ -1,20 +1,28 @@ -import SimpleITK as sitk +from __future__ import annotations + import numpy as np +import SimpleITK as sitk from skimage.transform import resize def resize_image(image, old_spacing, new_spacing, order=3): - new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), - int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), - int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) - return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False) + new_shape = ( + int(np.round(old_spacing[0] / new_spacing[0] * float(image.shape[0]))), + int(np.round(old_spacing[1] / new_spacing[1] * float(image.shape[1]))), + int(np.round(old_spacing[2] / new_spacing[2] * float(image.shape[2]))), + ) + return resize( + image, new_shape, order=order, mode="edge", cval=0, anti_aliasing=False + ) def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)): spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]] image = sitk.GetArrayFromImage(itk_image).astype(float) - assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed" + assert ( + len(image.shape) == 3 + ), "The image has unsupported number of dimensions. Only 3D images are allowed" if not is_seg: if np.any([[i != j] for i, j in zip(spacing, spacing_target)]): @@ -23,9 +31,11 @@ def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)): image -= image.mean() image /= image.std() else: - new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))), - int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))), - int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2])))) + new_shape = ( + int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))), + int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))), + int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))), + ) image = resize_segmentation(image, new_shape, 1) return image @@ -39,16 +49,18 @@ def load_and_preprocess(mri_file): "spacing": images["T1"].GetSpacing(), "direction": images["T1"].GetDirection(), "size": images["T1"].GetSize(), - "origin": images["T1"].GetOrigin() + "origin": images["T1"].GetOrigin(), } - for k in images.keys(): - images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5)) + for k in images: + images[k] = preprocess_image( + images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5) + ) - properties_dict['size_before_cropping'] = images["T1"].shape + properties_dict["size_before_cropping"] = images["T1"].shape imgs = [] - for seq in ['T1']: + for seq in ["T1"]: imgs.append(images[seq][None]) all_data = np.vstack(imgs) print("image shape after preprocessing: ", str(all_data[0].shape)) @@ -56,7 +68,7 @@ def load_and_preprocess(mri_file): def save_segmentation_nifti(segmentation, dct, out_fname, order=1, dtype=np.uint8): - ''' + """ segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out of the original image @@ -72,31 +84,38 @@ def save_segmentation_nifti(segmentation, dct, out_fname, order=1, dtype=np.uint :param dct: :param out_fname: :return: - ''' - old_size = dct.get('size_before_cropping') - bbox = dct.get('brain_bbox') + """ + old_size = dct.get("size_before_cropping") + bbox = dct.get("brain_bbox") if bbox is not None: seg_old_size = np.zeros(old_size) for c in range(3): bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c])) - seg_old_size[bbox[0][0]:bbox[0][1], - bbox[1][0]:bbox[1][1], - bbox[2][0]:bbox[2][1]] = segmentation + seg_old_size[ + bbox[0][0] : bbox[0][1], bbox[1][0] : bbox[1][1], bbox[2][0] : bbox[2][1] + ] = segmentation else: seg_old_size = segmentation - if np.any([i != j for i, j in zip(np.array(seg_old_size), np.array(dct['size'])[[2, 1, 0]])]): - seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order) + if np.any( + [ + i != j + for i, j in zip(np.array(seg_old_size), np.array(dct["size"])[[2, 1, 0]]) + ] + ): + seg_old_spacing = resize_segmentation( + seg_old_size, np.array(dct["size"])[[2, 1, 0]], order=order + ) else: seg_old_spacing = seg_old_size seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(dtype)) - seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]]) - seg_resized_itk.SetOrigin(dct['origin']) - seg_resized_itk.SetDirection(dct['direction']) + seg_resized_itk.SetSpacing(np.array(dct["spacing"])[[0, 1, 2]]) + seg_resized_itk.SetOrigin(dct["origin"]) + seg_resized_itk.SetDirection(dct["direction"]) sitk.WriteImage(seg_resized_itk, out_fname) def resize_segmentation(segmentation, new_shape, order=3, cval=0): - ''' + """ Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one @@ -106,16 +125,33 @@ def resize_segmentation(segmentation, new_shape, order=3, cval=0): :param new_shape: :param order: :return: - ''' + """ tpe = segmentation.dtype unique_labels = np.unique(segmentation) - assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation" + assert len(segmentation.shape) == len( + new_shape + ), "new shape must have same dimensionality as segmentation" if order == 0: - return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe) + return resize( + segmentation, + new_shape, + order, + mode="constant", + cval=cval, + clip=True, + anti_aliasing=False, + ).astype(tpe) else: reshaped = np.zeros(new_shape, dtype=segmentation.dtype) for i, c in enumerate(unique_labels): - reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False) + reshaped_multihot = resize( + (segmentation == c).astype(float), + new_shape, + order, + mode="edge", + clip=True, + anti_aliasing=False, + ) reshaped[reshaped_multihot >= 0.5] = c return reshaped diff --git a/src/hd_bet/hd_bet_cli.py b/src/hd_bet/hd_bet_cli.py index bfd79fa..0748049 100755 --- a/src/hd_bet/hd_bet_cli.py +++ b/src/hd_bet/hd_bet_cli.py @@ -1,52 +1,106 @@ -#!/usr/bin/env python +from __future__ import annotations import os -from HD_BET.run import run_hd_bet -from HD_BET.utils import maybe_mkdir_p, subfiles -import HD_BET +import hd_bet +from hd_bet.run import run_hd_bet +from hd_bet.utils import maybe_mkdir_p, subfiles -if __name__ == "__main__": + +def main(): print("\n########################") print("If you are using hd-bet, please cite the following paper:") - print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W," - "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial" - "neural networks. arXiv preprint arXiv:1901.11341, 2019.") + print( + "Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W," + "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial" + "neural networks. arXiv preprint arXiv:1901.11341, 2019." + ) print("########################\n") import argparse + parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be ' - 'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to ' - 'split 4d sequences into 3d images. If folder: all files ending with .nii.gz ' - 'within that folder will be brain extracted.', required=True, type=str) - parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder' - ' will be created', required=False, type=str) - parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will ' - 'use only one set of parameters whereas accurate will ' - 'use the five sets of parameters that resulted from ' - 'our cross-validation as an ensemble. Default: ' - 'accurate', - required=False) - parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. ' - 'Must be either int or str. Use int for GPU id or ' - '\'cpu\' to run on CPU. When using CPU you should ' - 'consider disabling tta. Default for -device is: 0', - required=False) - parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation ' - '(mirroring). 1= True, 0=False. Disable this ' - 'if you are using CPU to speed things up! ' - 'Default: 1') - parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disabe postprocessing (remove all' - ' but the largest connected component in ' - 'the prediction. Default: 1') - parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation ' - 'mask will not be ' - 'saved') - parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't " - "want to overwrite existing " - "predictions") - parser.add_argument('-b','--bet', default=1, type=int, required=False, help="set this to 0 if you don't want to save skull-stripped brain") + parser.add_argument( + "-i", + "--input", + help="input. Can be either a single file name or an input folder. If file: must be " + "nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to " + "split 4d sequences into 3d images. If folder: all files ending with .nii.gz " + "within that folder will be brain extracted.", + required=True, + type=str, + ) + parser.add_argument( + "-o", + "--output", + help="output. Can be either a filename or a folder. If it does not exist, the folder" + " will be created", + required=False, + type=str, + ) + parser.add_argument( + "-mode", + type=str, + default="accurate", + help="can be either 'fast' or 'accurate'. Fast will " + "use only one set of parameters whereas accurate will " + "use the five sets of parameters that resulted from " + "our cross-validation as an ensemble. Default: " + "accurate", + required=False, + ) + parser.add_argument( + "-device", + default="0", + type=str, + help="used to set on which device the prediction will run. " + "Must be either int or str. Use int for GPU id or " + "'cpu' to run on CPU. When using CPU you should " + "consider disabling tta. Default for -device is: 0", + required=False, + ) + parser.add_argument( + "-tta", + default=1, + required=False, + type=int, + help="whether to use test time data augmentation " + "(mirroring). 1= True, 0=False. Disable this " + "if you are using CPU to speed things up! " + "Default: 1", + ) + parser.add_argument( + "-pp", + default=1, + type=int, + required=False, + help="set to 0 to disabe postprocessing (remove all" + " but the largest connected component in " + "the prediction. Default: 1", + ) + parser.add_argument( + "-s", + "--save_mask", + default=1, + type=int, + required=False, + help="if set to 0 the segmentation " "mask will not be " "saved", + ) + parser.add_argument( + "--overwrite_existing", + default=1, + type=int, + required=False, + help="set this to 0 if you don't " "want to overwrite existing " "predictions", + ) + parser.add_argument( + "-b", + "--bet", + default=1, + type=int, + required=False, + help="set this to 0 if you don't want to save skull-stripped brain", + ) args = parser.parse_args() @@ -54,8 +108,10 @@ output_file_or_dir = args.output if output_file_or_dir is None: - output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir), - os.path.basename(input_file_or_dir).split(".")[0] + "_bet") + output_file_or_dir = os.path.join( + os.path.dirname(input_file_or_dir), + os.path.basename(input_file_or_dir).split(".")[0] + "_bet", + ) mode = args.mode device = args.device @@ -65,29 +121,35 @@ overwrite_existing = args.overwrite_existing bet = args.bet - params_file = os.path.join(HD_BET.__path__[0], "model_final.py") - config_file = os.path.join(HD_BET.__path__[0], "config.py") + params_file = os.path.join(hd_bet.__path__[0], "model_final.py") + config_file = os.path.join(hd_bet.__path__[0], "config.py") - assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input" + assert os.path.abspath(input_file_or_dir) != os.path.abspath( + output_file_or_dir + ), "output must be different from input" - if device == 'cpu': + if device == "cpu": pass else: device = int(device) if os.path.isdir(input_file_or_dir): maybe_mkdir_p(output_file_or_dir) - input_files = subfiles(input_file_or_dir, suffix='.nii.gz', join=False) + input_files = subfiles(input_file_or_dir, suffix=".nii.gz", join=False) if len(input_files) == 0: - raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here") + raise RuntimeError( + "input is a folder but no nifti files (.nii.gz) were found in here" + ) output_files = [os.path.join(output_file_or_dir, i) for i in input_files] input_files = [os.path.join(input_file_or_dir, i) for i in input_files] else: - if not output_file_or_dir.endswith('.nii.gz'): - output_file_or_dir += '.nii.gz' - assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input" + if not output_file_or_dir.endswith(".nii.gz"): + output_file_or_dir += ".nii.gz" + assert os.path.abspath(input_file_or_dir) != os.path.abspath( + output_file_or_dir + ), "output must be different from input" output_files = [output_file_or_dir] input_files = [input_file_or_dir] @@ -104,7 +166,10 @@ elif overwrite_existing == 1: overwrite_existing = True else: - raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing)) + raise ValueError( + "Unknown value for overwrite_existing: %s. Expected: 0 or 1" + % str(overwrite_existing) + ) if pp == 0: pp = False @@ -118,7 +183,9 @@ elif save_mask == 1: save_mask = True else: - raise ValueError("Unknown value for save_mask: %s. Expected: 0 or 1" % str(save_mask)) + raise ValueError( + "Unknown value for save_mask: %s. Expected: 0 or 1" % str(save_mask) + ) if bet == 0: if save_mask: @@ -130,5 +197,16 @@ bet = True else: raise ValueError("Unknown value for bet: %s. Expected: 0 or 1" % str(pp)) - - run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing, bet) + + run_hd_bet( + input_files, + output_files, + mode, + config_file, + device, + pp, + tta, + save_mask, + overwrite_existing, + bet, + ) diff --git a/src/hd_bet/network_architecture.py b/src/hd_bet/network_architecture.py index 0824aa1..971a42b 100755 --- a/src/hd_bet/network_architecture.py +++ b/src/hd_bet/network_architecture.py @@ -1,38 +1,74 @@ +from __future__ import annotations + import torch -import torch.nn as nn import torch.nn.functional as F -from HD_BET.utils import softmax_helper +from torch import nn + +from hd_bet.utils import softmax_helper class EncodingModule(nn.Module): - def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True): + def __init__( + self, + in_channels, + out_channels, + filter_size=3, + dropout_p=0.3, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ): nn.Module.__init__(self) self.dropout_p = dropout_p self.lrelu_inplace = lrelu_inplace self.inst_norm_affine = inst_norm_affine self.conv_bias = conv_bias self.leakiness = leakiness - self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) - self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias) + self.bn_1 = nn.InstanceNorm3d( + in_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv1 = nn.Conv3d( + in_channels, + out_channels, + filter_size, + 1, + (filter_size - 1) // 2, + bias=self.conv_bias, + ) self.dropout = nn.Dropout3d(dropout_p) - self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) - self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias) + self.bn_2 = nn.InstanceNorm3d( + in_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv2 = nn.Conv3d( + out_channels, + out_channels, + filter_size, + 1, + (filter_size - 1) // 2, + bias=self.conv_bias, + ) def forward(self, x): skip = x - x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) x = self.conv1(x) if self.dropout_p is not None and self.dropout_p > 0: x = self.dropout(x) - x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) x = self.conv2(x) x = x + skip return x class Upsample(nn.Module): - def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True): + def __init__( + self, size=None, scale_factor=None, mode="nearest", align_corners=True + ): super(Upsample, self).__init__() self.align_corners = align_corners self.mode = mode @@ -40,68 +76,129 @@ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=T self.size = size def forward(self, x): - return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, - align_corners=self.align_corners) + return nn.functional.interpolate( + x, + size=self.size, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) class LocalizationModule(nn.Module): - def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, - lrelu_inplace=True): + def __init__( + self, + in_channels, + out_channels, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ): nn.Module.__init__(self) self.lrelu_inplace = lrelu_inplace self.inst_norm_affine = inst_norm_affine self.conv_bias = conv_bias self.leakiness = leakiness self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias) - self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) + self.bn_1 = nn.InstanceNorm3d( + in_channels, affine=self.inst_norm_affine, track_running_stats=True + ) self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias) - self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True) + self.bn_2 = nn.InstanceNorm3d( + out_channels, affine=self.inst_norm_affine, track_running_stats=True + ) def forward(self, x): - x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace) - x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.bn_1(self.conv1(x)), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) + x = F.leaky_relu( + self.bn_2(self.conv2(x)), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) return x class UpsamplingModule(nn.Module): - def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, - lrelu_inplace=True): + def __init__( + self, + in_channels, + out_channels, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ): nn.Module.__init__(self) self.lrelu_inplace = lrelu_inplace self.inst_norm_affine = inst_norm_affine self.conv_bias = conv_bias self.leakiness = leakiness self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True) - self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias) - self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True) + self.upsample_conv = nn.Conv3d( + in_channels, out_channels, 3, 1, 1, bias=self.conv_bias + ) + self.bn = nn.InstanceNorm3d( + out_channels, affine=self.inst_norm_affine, track_running_stats=True + ) def forward(self, x): - x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.bn(self.upsample_conv(self.upsample(x))), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) return x class DownsamplingModule(nn.Module): - def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, - lrelu_inplace=True): + def __init__( + self, + in_channels, + out_channels, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ): nn.Module.__init__(self) self.lrelu_inplace = lrelu_inplace self.inst_norm_affine = inst_norm_affine self.conv_bias = conv_bias self.leakiness = leakiness - self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) - self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias) + self.bn = nn.InstanceNorm3d( + in_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.downsample = nn.Conv3d( + in_channels, out_channels, 3, 2, 1, bias=self.conv_bias + ) def forward(self, x): - x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) b = self.downsample(x) return x, b class Network(nn.Module): - def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3, - final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, - lrelu_inplace=True, do_ds=True): + def __init__( + self, + num_classes=4, + num_input_channels=4, + base_filters=16, + dropout_p=0.3, + final_nonlin=softmax_helper, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + do_ds=True, + ): super(Network, self).__init__() self.do_ds = do_ds @@ -110,56 +207,174 @@ def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout self.conv_bias = conv_bias self.leakiness = leakiness self.final_nonlin = final_nonlin - self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias) - - self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2, - conv_bias=True, inst_norm_affine=True, lrelu_inplace=True) - - self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True) - self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) + self.init_conv = nn.Conv3d( + num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias + ) + + self.context1 = EncodingModule( + base_filters, + base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.down1 = DownsamplingModule( + base_filters, + base_filters * 2, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.context2 = EncodingModule( + 2 * base_filters, + 2 * base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.down2 = DownsamplingModule( + 2 * base_filters, + base_filters * 4, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.context3 = EncodingModule( + 4 * base_filters, + 4 * base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.down3 = DownsamplingModule( + 4 * base_filters, + base_filters * 8, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.context4 = EncodingModule( + 8 * base_filters, + 8 * base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.down4 = DownsamplingModule( + 8 * base_filters, + base_filters * 16, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.context5 = EncodingModule( + 16 * base_filters, + 16 * base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.bn_after_context5 = nn.InstanceNorm3d( + 16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True + ) + self.up1 = UpsamplingModule( + 16 * base_filters, + 8 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.loc1 = LocalizationModule( + 16 * base_filters, + 8 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.up2 = UpsamplingModule( + 8 * base_filters, + 4 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.loc2 = LocalizationModule( + 8 * base_filters, + 4 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False) - self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) + self.up3 = UpsamplingModule( + 4 * base_filters, + 2 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.loc3 = LocalizationModule( + 4 * base_filters, + 2 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False) - self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias) - self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True) - self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias) - self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True) + self.up4 = UpsamplingModule( + 2 * base_filters, + 1 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.end_conv_1 = nn.Conv3d( + 2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias + ) + self.end_conv_1_bn = nn.InstanceNorm3d( + 2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True + ) + self.end_conv_2 = nn.Conv3d( + 2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias + ) + self.end_conv_2_bn = nn.InstanceNorm3d( + 2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True + ) self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False) def forward(self, x): @@ -180,7 +395,11 @@ def forward(self, x): skip4, x = self.down4(x) x = self.context5(x) - x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.bn_after_context5(x), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) x = self.up1(x) x = torch.cat((skip4, x), dim=1) @@ -200,10 +419,16 @@ def forward(self, x): x = self.up4(x) x = torch.cat((skip1, x), dim=1) - x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) - x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) + x = F.leaky_relu( + self.end_conv_1_bn(self.end_conv_1(x)), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) + x = F.leaky_relu( + self.end_conv_2_bn(self.end_conv_2(x)), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) x = self.final_nonlin(self.seg_layer(x)) seg_outputs.append(x) diff --git a/src/hd_bet/paths.py b/src/hd_bet/paths.py index 13b2e65..e93a734 100644 --- a/src/hd_bet/paths.py +++ b/src/hd_bet/paths.py @@ -1,4 +1,6 @@ +from __future__ import annotations + import os # please refer to the readme on where to get the parameters. Save them in this folder: -folder_with_parameter_files = os.path.join(os.path.expanduser('~'), 'hd-bet_params') +folder_with_parameter_files = os.path.join(os.path.expanduser("~"), "hd-bet_params") diff --git a/src/hd_bet/predict_case.py b/src/hd_bet/predict_case.py index 559c667..f7e2268 100755 --- a/src/hd_bet/predict_case.py +++ b/src/hd_bet/predict_case.py @@ -1,14 +1,18 @@ -import torch +from __future__ import annotations + import numpy as np +import torch def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None): - if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)): + if not (isinstance(shape_must_be_divisible_by, (list, tuple))): shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3 shp = patient.shape - new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0], - shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1], - shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]] + new_shp = [ + shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0], + shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1], + shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2], + ] for i in range(len(shp)): if shp[i] % shape_must_be_divisible_by[i] == 0: new_shp[i] -= shape_must_be_divisible_by[i] @@ -19,28 +23,41 @@ def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None): def reshape_by_padding_upper_coords(image, new_shape, pad_value=None): shape = tuple(list(image.shape)) - new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0)) + new_shape = tuple( + np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0) + ) if pad_value is None: if len(shape) == 2: - pad_value = image[0,0] + pad_value = image[0, 0] elif len(shape) == 3: pad_value = image[0, 0, 0] else: raise ValueError("Image must be either 2 or 3 dimensional") res = np.ones(list(new_shape), dtype=image.dtype) * pad_value if len(shape) == 2: - res[0:0+int(shape[0]), 0:0+int(shape[1])] = image + res[0 : 0 + int(shape[0]), 0 : 0 + int(shape[1])] = image elif len(shape) == 3: - res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image + res[0 : 0 + int(shape[0]), 0 : 0 + int(shape[1]), 0 : 0 + int(shape[2])] = image return res -def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None, - new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)): +def predict_case_3D_net( + net, + patient_data, + do_mirroring, + num_repeats, + BATCH_SIZE=None, + new_shape_must_be_divisible_by=16, + min_size=None, + main_device=0, + mirror_axes=(2, 3, 4), +): with torch.no_grad(): pad_res = [] for i in range(patient_data.shape[0]): - t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size) + t, old_shape = pad_patient_3D( + patient_data[i], new_shape_must_be_divisible_by, min_size + ) pad_res.append(t[None]) patient_data = np.vstack(pad_res) @@ -56,7 +73,7 @@ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE a = torch.rand(data.shape).float() - if main_device == 'cpu': + if main_device == "cpu": pass else: a = a.cuda(main_device) @@ -72,7 +89,6 @@ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE do_stuff = False if m == 0: do_stuff = True - pass if m == 1 and (4 in mirror_axes): do_stuff = True data_for_net = data_for_net[:, :, :, :, ::-1] @@ -91,13 +107,20 @@ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE if m == 6 and (2 in mirror_axes) and (3 in mirror_axes): do_stuff = True data_for_net = data_for_net[:, :, ::-1, ::-1, :] - if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes): + if ( + m == 7 + and (2 in mirror_axes) + and (3 in mirror_axes) + and (4 in mirror_axes) + ): do_stuff = True data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1] if do_stuff: _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net))) - p = net(a) # np.copy is necessary because ::-1 creates just a view i think + p = net( + a + ) # np.copy is necessary because ::-1 creates just a view i think p = p.data.cpu().numpy() if m == 0: @@ -114,11 +137,18 @@ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE p = p[:, :, ::-1, :, ::-1] if m == 6 and (2 in mirror_axes) and (3 in mirror_axes): p = p[:, :, ::-1, ::-1, :] - if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes): + if ( + m == 7 + and (2 in mirror_axes) + and (3 in mirror_axes) + and (4 in mirror_axes) + ): p = p[:, :, ::-1, ::-1, ::-1] all_preds.append(p) - stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]] + stacked = np.vstack(all_preds)[ + :, :, : old_shape[0], : old_shape[1], : old_shape[2] + ] predicted_segmentation = stacked.mean(0).argmax(0) uncertainty = stacked.var(0) bayesian_predictions = stacked diff --git a/src/hd_bet/run.py b/src/hd_bet/run.py index 8c6f08d..7830e52 100755 --- a/src/hd_bet/run.py +++ b/src/hd_bet/run.py @@ -1,12 +1,21 @@ -import torch -import numpy as np -import SimpleITK as sitk -from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti -from HD_BET.predict_case import predict_case_3D_net +from __future__ import annotations + import imp -from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters import os -import HD_BET + +import numpy as np +import SimpleITK as sitk +import torch + +import hd_bet +from hd_bet.data_loading import load_and_preprocess, save_segmentation_nifti +from hd_bet.predict_case import predict_case_3D_net +from hd_bet.utils import ( + SetNetworkToVal, + get_params_fname, + maybe_download_parameters, + postprocess_prediction, +) def apply_bet(img, bet, out_fname): @@ -19,8 +28,18 @@ def apply_bet(img, bet, out_fname): sitk.WriteImage(out, out_fname) -def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0, - postprocess=False, do_tta=True, keep_mask=True, overwrite=True, bet=False): +def run_hd_bet( + mri_fnames, + output_fnames, + mode="accurate", + config_file=os.path.join(hd_bet.__path__[0], "config.py"), + device=0, + postprocess=False, + do_tta=True, + keep_mask=True, + overwrite=True, + bet=False, +): """ :param mri_fnames: str or list/tuple of str @@ -37,23 +56,27 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j list_of_param_files = [] - if mode == 'fast': + if mode == "fast": params_file = get_params_fname(0) maybe_download_parameters(0) list_of_param_files.append(params_file) - elif mode == 'accurate': + elif mode == "accurate": for i in range(5): params_file = get_params_fname(i) maybe_download_parameters(i) list_of_param_files.append(params_file) else: - raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode) + raise ValueError( + "Unknown value for mode: %s. Expected: fast or accurate" % mode + ) - assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files" + assert all( + [os.path.isfile(i) for i in list_of_param_files] + ), "Could not find parameter files" - cf = imp.load_source('cf', config_file) + cf = imp.load_source("cf", config_file) cf = cf.config() net, _ = cf.get_network(cf.val_use_train_mode, None) @@ -68,7 +91,9 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j if not isinstance(output_fnames, (list, tuple)): output_fnames = [output_fnames] - assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length" + assert len(mri_fnames) == len( + output_fnames + ), "mri_fnames and output_fnames must have the same length" params = [] for p in list_of_param_files: @@ -76,7 +101,10 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j for in_fname, out_fname in zip(mri_fnames, output_fnames): mask_fname = out_fname[:-7] + "_mask.nii.gz" - if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)): + if overwrite or ( + not (os.path.isfile(mask_fname) and keep_mask) + or not os.path.isfile(out_fname) + ): print("File:", in_fname) print("preprocessing...") try: @@ -96,9 +124,17 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j net.load_state_dict(p) net.eval() net.apply(SetNetworkToVal(False, False)) - _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats, - cf.val_batch_size, cf.net_input_must_be_divisible_by, - cf.val_min_size, device, cf.da_mirror_axes) + _, _, softmax_pred, _ = predict_case_3D_net( + net, + data, + do_tta, + cf.val_num_repeats, + cf.val_batch_size, + cf.net_input_must_be_divisible_by, + cf.val_min_size, + device, + cf.da_mirror_axes, + ) softmax_preds.append(softmax_pred[None]) seg = np.argmax(np.vstack(softmax_preds).mean(0), 0) @@ -113,5 +149,3 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j if not keep_mask: os.remove(mask_fname) - - diff --git a/src/hd_bet/utils.py b/src/hd_bet/utils.py index f70f389..a918f7c 100755 --- a/src/hd_bet/utils.py +++ b/src/hd_bet/utils.py @@ -1,10 +1,14 @@ +from __future__ import annotations + +import os from urllib.request import urlopen -import torch -from torch import nn + import numpy as np +import torch from skimage.morphology import label -import os -from HD_BET.paths import folder_with_parameter_files +from torch import nn + +from hd_bet.paths import folder_with_parameter_files def get_params_fname(fold): @@ -33,7 +37,7 @@ def maybe_download_parameters(fold=0, force_overwrite=False): url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold print("Downloading", url, "...") data = urlopen(url).read() - with open(out_filename, 'wb') as f: + with open(out_filename, "wb") as f: f.write(data) @@ -52,18 +56,25 @@ def softmax_helper(x): return e_x / e_x.sum(1, keepdim=True).repeat(*rpt) -class SetNetworkToVal(object): +class SetNetworkToVal: def __init__(self, use_dropout_sampling=False, norm_use_average=True): self.norm_use_average = norm_use_average self.use_dropout_sampling = use_dropout_sampling def __call__(self, module): - if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout): + if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)): module.train(self.use_dropout_sampling) - elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \ - isinstance(module, nn.InstanceNorm1d) \ - or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \ - isinstance(module, nn.BatchNorm1d): + elif isinstance( + module, + ( + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + ), + ): module.train(not self.norm_use_average) @@ -83,9 +94,13 @@ def subdirs(folder, join=True, prefix=None, suffix=None, sort=True): l = os.path.join else: l = lambda x, y: y - res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i)) - and (prefix is None or i.startswith(prefix)) - and (suffix is None or i.endswith(suffix))] + res = [ + l(folder, i) + for i in os.listdir(folder) + if os.path.isdir(os.path.join(folder, i)) + and (prefix is None or i.startswith(prefix)) + and (suffix is None or i.endswith(suffix)) + ] if sort: res.sort() return res @@ -96,9 +111,13 @@ def subfiles(folder, join=True, prefix=None, suffix=None, sort=True): l = os.path.join else: l = lambda x, y: y - res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) - and (prefix is None or i.startswith(prefix)) - and (suffix is None or i.endswith(suffix))] + res = [ + l(folder, i) + for i in os.listdir(folder) + if os.path.isfile(os.path.join(folder, i)) + and (prefix is None or i.startswith(prefix)) + and (suffix is None or i.endswith(suffix)) + ] if sort: res.sort() return res @@ -109,6 +128,6 @@ def subfiles(folder, join=True, prefix=None, suffix=None, sort=True): def maybe_mkdir_p(directory): splits = directory.split("/")[1:] - for i in range(0, len(splits)): - if not os.path.isdir(os.path.join("/", *splits[:i+1])): - os.mkdir(os.path.join("/", *splits[:i+1])) + for i in range(len(splits)): + if not os.path.isdir(os.path.join("/", *splits[: i + 1])): + os.mkdir(os.path.join("/", *splits[: i + 1]))