diff --git a/DLMUSE/__main__.py b/DLMUSE/__main__.py index 396a086..7f8365e 100644 --- a/DLMUSE/__main__.py +++ b/DLMUSE/__main__.py @@ -1,29 +1,23 @@ import argparse -import json -import os -from pathlib import Path -import shutil -import sys import warnings -import torch - -from .utils import prepare_data_folder, rename_and_copy_files +from .dlmuse_pipeline import run_pipeline warnings.simplefilter(action="ignore", category=FutureWarning) warnings.simplefilter(action="ignore", category=UserWarning) VERSION = 1.0 + def main() -> None: - prog="DLMUSE" + prog = "DLMUSE" parser = argparse.ArgumentParser( prog=prog, description="DLMUSE - MUlti-atlas region Segmentation utilizing Ensembles of registration algorithms and parameters.", usage=""" DLMUSE v{VERSION} Segmentation of the brain into MUSE ROIs from the Nifti image (.nii.gz) of the LPS oriented Intra Cranial Volume (ICV - see DLICV method). - + Required arguments: [-i, --in_dir] The filepath of the input directory [-o, --out_dir] The filepath of the output directory @@ -36,8 +30,10 @@ def main() -> None: -o /path/to/output \ -device cpu|cuda|mps - """.format(VERSION=VERSION), - add_help=False + """.format( + VERSION=VERSION + ), + add_help=False, ) # Required Arguments @@ -53,7 +49,7 @@ def main() -> None: required=True, help="[REQUIRED] Output folder for the segmentation results in Nifti format (nii.gz).", ) - + # Optional Arguments parser.add_argument( "-device", @@ -99,7 +95,7 @@ def main() -> None: action="store_true", required=False, default=False, - help="Set this flag to clear any cached models before running. This is recommended if a previous download failed." + help="Set this flag to clear any cached models before running. This is recommended if a previous download failed.", ) parser.add_argument( "--disable_tta", @@ -109,7 +105,8 @@ def main() -> None: help="[nnUnet Arg] Set this flag to disable test time data augmentation in the form of mirroring. " "Faster, but less accurate inference. Not recommended.", ) - ### DEPRECIATED #### + + # DEPRECATED # parser.add_argument( # "-m", # type=str, @@ -208,158 +205,32 @@ def main() -> None: required=False, default=0, help="[nnUnet Arg] If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 " - "can end with num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set " + "can end with num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set " "-num_parts 5 and use -part_id 0, 1, 2, 3 and 4. Note: You are yourself responsible to make these run on separate GPUs! " "Use CUDA_VISIBLE_DEVICES.", ) - - args = parser.parse_args() - args.f = [args.f] - - if args.clear_cache: - shutil.rmtree(os.path.join( - Path(__file__).parent, - "nnunet_results" - )) - shutil.rmtree(os.path.join( - Path(__file__).parent, - ".cache" - )) - if not args.i or not args.o: - print("Cache cleared and missing either -i / -o. Exiting.") - sys.exit(0) - - if not args.i or not args.o: - parser.error("The following arguments are required: -i, -o") - - # data conversion - src_folder = args.i # input folder - - if not os.path.exists(args.o): # create output folder if it does not exist - os.makedirs(args.o) - - des_folder = os.path.join(args.o, "renamed_image") - - # check if -i argument is a folder, list (csv), or a single file (nii.gz) - if os.path.isdir(args.i): # if args.i is a directory - src_folder = args.i - prepare_data_folder(des_folder) - rename_dic, rename_back_dict = rename_and_copy_files(src_folder, des_folder) - datalist_file = os.path.join(des_folder, "renaming.json") - with open(datalist_file, "w", encoding="utf-8") as f: - json.dump(rename_dic, f, ensure_ascii=False, indent=4) - print(f"Renaming dic is saved to {datalist_file}") - else: - print("Input directory not found. Exiting DLMUSE.") - sys.exit() - - model_folder = os.path.join( - Path(__file__).parent, - "nnunet_results", - "Dataset%s_Task%s_DLMUSEV2/nnUNetTrainer__nnUNetPlans__%s/" - % (args.d, args.d, args.c), + run_pipeline( + args.in_dir, + args.out_dir, + args.device, + args.clear_cache, + args.d, + args.part_id, + args.num_parts, + args.step_size, + args.disable_tta, + args.verbose, + args.disable_progress_bar, + args.chk, + args.save_probabilities, + args.continue_prediction, + args.npp, + args.nps, + args.prev_stage_predictions, ) - # Check if model exists. If not exist, download using HuggingFace - print(f"Using model folder: {model_folder}") - if not os.path.exists(model_folder): - # HF download model - print("DLMUSE model not found, downloading...") - - from huggingface_hub import snapshot_download - local_src = Path(__file__).parent - snapshot_download(repo_id="nichart/DLMUSE", local_dir=local_src) - - print("DLMUSE model has been successfully downloaded!") - else: - print("Loading the model...") - - prepare_data_folder(des_folder) - - assert ( - args.part_id < args.num_parts - ), "part_id < num_parts. Please see nnUNetv2_predict -h." - - assert args.device in [ - "cpu", - "cuda", - "mps", - ], f"-device must be either cpu, mps or cuda. Got: {args.device}." - - if args.device == "cpu": - import multiprocessing - # use half of the available threads in the system. - torch.set_num_threads(multiprocessing.cpu_count() // 2) - device = torch.device("cpu") - print("Running in CPU mode.") - elif args.device == "cuda": - # multithreading in torch doesn't help nnU-Net if run on GPU - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - device = torch.device("cuda") - print("Running in CUDA mode.") - else: - device = torch.device("mps") - print("Running in MPS mode.") - - # exports for nnunetv2 purposes - os.environ["nnUNet_raw"] = "/nnunet_raw/" - os.environ["nnUNet_preprocessed"] = "/nnunet_preprocessed" - os.environ["nnUNet_results"] = ( - "/nnunet_results" # where model will be located (fetched from HF) - ) - - from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor - - # Initialize nnUnetPredictor - predictor = nnUNetPredictor( - tile_step_size=args.step_size, - use_gaussian=True, - use_mirroring=not args.disable_tta, - perform_everything_on_device=True, - device=device, - verbose=args.verbose, - verbose_preprocessing=args.verbose, - allow_tqdm=not args.disable_progress_bar, - ) - - # Retrieve the model and it's weight - predictor.initialize_from_trained_model_folder( - model_folder, args.f, checkpoint_name=args.chk - ) - - # Final prediction - predictor.predict_from_files( - des_folder, - args.o, - save_probabilities=args.save_probabilities, - overwrite=not args.continue_prediction, - num_processes_preprocessing=args.npp, - num_processes_segmentation_export=args.nps, - folder_with_segs_from_prev_stage=args.prev_stage_predictions, - num_parts=args.num_parts, - part_id=args.part_id, - ) - - # After prediction, convert the image name back to original - files_folder = args.o - - for filename in os.listdir(files_folder): - if filename.endswith(".nii.gz"): - original_name = rename_back_dict[filename] - os.rename( - os.path.join(files_folder, filename), - os.path.join(files_folder, original_name), - ) - # Remove the (temporary) des_folder directory - if os.path.exists(des_folder): - shutil.rmtree(des_folder) - - print("Inference Process Done!") - - if __name__ == "__main__": main() diff --git a/DLMUSE/dlmuse_pipeline.py b/DLMUSE/dlmuse_pipeline.py new file mode 100644 index 0000000..d51aa52 --- /dev/null +++ b/DLMUSE/dlmuse_pipeline.py @@ -0,0 +1,175 @@ +import json +import os +import shutil +import sys +from pathlib import Path +from typing import Optional + +import torch + +from .utils import prepare_data_folder, rename_and_copy_files + + +def run_pipeline( + in_dir: str, + out_dir: str, + device: str, + clear_cache: bool = False, + d: str = "901", + part_id: int = 0, + num_parts: int = 1, + step_size: float = 0.5, + disable_tta: bool = False, + verbose: bool = False, + disable_progress_bar: bool = False, + chk: bool = False, + save_probabilities: bool = False, + continue_prediction: bool = False, + npp: int = 2, + nps: int = 2, + prev_stage_predictions: Optional[str] = None, +) -> None: + """ + Run dlmuse pipeline function + :param in_dir: The input directory + :type in_dir: str + :param out_dir: The output directory + :type out_dir: str + :param device: cpu/cuda/mps + :type device: str + + Any other argument is not needed for 99% of you. + Devs should see the code + + :rtype: None + """ + f = [0] + if clear_cache: + shutil.rmtree(os.path.join(Path(__file__).parent, "nnunet_results")) + shutil.rmtree(os.path.join(Path(__file__).parent, ".cache")) + if not in_dir or not out_dir: + print("Cache cleared and missing either -i / -o. Exiting.") + sys.exit(0) + + if not in_dir or not out_dir: + print("The following arguments are required: -i, -o") + sys.exit(0) + + # data conversion + src_folder = in_dir # input folder + if not os.path.exists(out_dir): # create output folder if it does not exist + os.makedirs(out_dir) + + des_folder = os.path.join(out_dir, "renamed_image") + + # check if -i argument is a folder, list (csv), or a single file (nii.gz) + if os.path.isdir(in_dir): # if args.i is a directory + src_folder = in_dir + prepare_data_folder(des_folder) + rename_dic, rename_back_dict = rename_and_copy_files(src_folder, des_folder) + datalist_file = os.path.join(des_folder, "renaming.json") + with open(datalist_file, "w", encoding="utf-8") as ff: + json.dump(rename_dic, ff, ensure_ascii=False, indent=4) + print(f"Renaming dic is saved to {datalist_file}") + + model_folder = os.path.join( + Path(__file__).parent, + "nnunet_results", + "Dataset%s_Task%s_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/" % (d, d), + ) + + if clear_cache: + shutil.rmtree(os.path.join(Path(__file__).parent, "nnunet_results")) + shutil.rmtree(os.path.join(Path(__file__).parent, ".cache")) + + # Check if model exists. If not exist, download using HuggingFace + if not os.path.exists(model_folder): + # HF download model + print("DLICV model not found, downloading...") + + from huggingface_hub import snapshot_download + + local_src = Path(__file__).parent + snapshot_download(repo_id="nichart/DLICV", local_dir=local_src) + print("DLICV model has been successfully downloaded!") + else: + print("Loading the model...") + + prepare_data_folder(out_dir) + + # Check for invalid arguments - advise users to see nnUNetv2 documentation + assert part_id < num_parts, "See nnUNetv2_predict -h." + + assert device in [ + "cpu", + "cuda", + "mps", + ], f"-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}." + + if device == "cpu": + import multiprocessing + + torch.set_num_threads( + multiprocessing.cpu_count() // 2 + ) # use half of the threads (better for PC) + device = torch.device("cpu") + elif device == "cuda": + # multithreading in torch doesn't help nnU-Netv2 if run on GPU + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + device = torch.device("cuda") + else: + device = torch.device("mps") + + # exports for nnunetv2 purposes + os.environ["nnUNet_raw"] = "/nnunet_raw/" + os.environ["nnUNet_preprocessed"] = "/nnunet_preprocessed" + os.environ["nnUNet_results"] = ( + "/nnunet_results" # where model will be located (fetched from HF) + ) + + from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor + + # Initialize nnUnetPredictor + predictor = nnUNetPredictor( + tile_step_size=step_size, + use_gaussian=True, + use_mirroring=not disable_tta, + perform_everything_on_device=True, + device=device, + verbose=verbose, + verbose_preprocessing=verbose, + allow_tqdm=not disable_progress_bar, + ) + + # Retrieve the model and its weight + predictor.initialize_from_trained_model_folder(model_folder, f, checkpoint_name=chk) + + # Final prediction + predictor.predict_from_files( + des_folder, + out_dir, + save_probabilities=save_probabilities, + overwrite=not continue_prediction, + num_processes_preprocessing=npp, + num_processes_segmentation_export=nps, + folder_with_segs_from_prev_stage=prev_stage_predictions, + num_parts=num_parts, + part_id=part_id, + ) + + # After prediction, convert the image name back to original + files_folder = out_dir + + for filename in os.listdir(files_folder): + if filename.endswith(".nii.gz"): + original_name = rename_back_dict[filename] + os.rename( + os.path.join(files_folder, filename), + os.path.join(files_folder, original_name), + ) + # Remove the (temporary) des_folder directory + if os.path.exists(des_folder): + shutil.rmtree(des_folder) + + print("DLICV Process Done!")