Skip to content

Commit

Permalink
Created DLMUSE pipelien file for better import | arguments fixed and …
Browse files Browse the repository at this point in the history
…working with just the 3 basic(i, o, d)
  • Loading branch information
spirosmaggioros committed Dec 11, 2024
1 parent bc678dc commit 0306098
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 160 deletions.
191 changes: 31 additions & 160 deletions DLMUSE/__main__.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
import argparse
import json
import os
from pathlib import Path
import shutil
import sys
import warnings

import torch

from .utils import prepare_data_folder, rename_and_copy_files
from .dlmuse_pipeline import run_pipeline

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

VERSION = 1.0


def main() -> None:
prog="DLMUSE"
prog = "DLMUSE"
parser = argparse.ArgumentParser(
prog=prog,
description="DLMUSE - MUlti-atlas region Segmentation utilizing Ensembles of registration algorithms and parameters.",
usage="""
DLMUSE v{VERSION}
Segmentation of the brain into MUSE ROIs from the Nifti image (.nii.gz) of the LPS oriented Intra Cranial Volume (ICV - see DLICV method).
Required arguments:
[-i, --in_dir] The filepath of the input directory
[-o, --out_dir] The filepath of the output directory
Expand All @@ -36,8 +30,10 @@ def main() -> None:
-o /path/to/output \
-device cpu|cuda|mps
""".format(VERSION=VERSION),
add_help=False
""".format(
VERSION=VERSION
),
add_help=False,
)

# Required Arguments
Expand All @@ -53,7 +49,7 @@ def main() -> None:
required=True,
help="[REQUIRED] Output folder for the segmentation results in Nifti format (nii.gz).",
)

# Optional Arguments
parser.add_argument(
"-device",
Expand Down Expand Up @@ -99,7 +95,7 @@ def main() -> None:
action="store_true",
required=False,
default=False,
help="Set this flag to clear any cached models before running. This is recommended if a previous download failed."
help="Set this flag to clear any cached models before running. This is recommended if a previous download failed.",
)
parser.add_argument(
"--disable_tta",
Expand All @@ -109,7 +105,8 @@ def main() -> None:
help="[nnUnet Arg] Set this flag to disable test time data augmentation in the form of mirroring. "
"Faster, but less accurate inference. Not recommended.",
)
### DEPRECIATED ####

# DEPRECATED
# parser.add_argument(
# "-m",
# type=str,
Expand Down Expand Up @@ -208,158 +205,32 @@ def main() -> None:
required=False,
default=0,
help="[nnUnet Arg] If multiple nnUNetv2_predict exist, which one is this? IDs start with 0 "
"can end with num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set "
"can end with num_parts - 1. So when you submit 5 nnUNetv2_predict calls you need to set "
"-num_parts 5 and use -part_id 0, 1, 2, 3 and 4. Note: You are yourself responsible to make these run on separate GPUs! "
"Use CUDA_VISIBLE_DEVICES.",
)



args = parser.parse_args()
args.f = [args.f]

if args.clear_cache:
shutil.rmtree(os.path.join(
Path(__file__).parent,
"nnunet_results"
))
shutil.rmtree(os.path.join(
Path(__file__).parent,
".cache"
))
if not args.i or not args.o:
print("Cache cleared and missing either -i / -o. Exiting.")
sys.exit(0)

if not args.i or not args.o:
parser.error("The following arguments are required: -i, -o")

# data conversion
src_folder = args.i # input folder

if not os.path.exists(args.o): # create output folder if it does not exist
os.makedirs(args.o)

des_folder = os.path.join(args.o, "renamed_image")

# check if -i argument is a folder, list (csv), or a single file (nii.gz)
if os.path.isdir(args.i): # if args.i is a directory
src_folder = args.i
prepare_data_folder(des_folder)
rename_dic, rename_back_dict = rename_and_copy_files(src_folder, des_folder)
datalist_file = os.path.join(des_folder, "renaming.json")
with open(datalist_file, "w", encoding="utf-8") as f:
json.dump(rename_dic, f, ensure_ascii=False, indent=4)
print(f"Renaming dic is saved to {datalist_file}")
else:
print("Input directory not found. Exiting DLMUSE.")
sys.exit()

model_folder = os.path.join(
Path(__file__).parent,
"nnunet_results",
"Dataset%s_Task%s_DLMUSEV2/nnUNetTrainer__nnUNetPlans__%s/"
% (args.d, args.d, args.c),
run_pipeline(
args.in_dir,
args.out_dir,
args.device,
args.clear_cache,
args.d,
args.part_id,
args.num_parts,
args.step_size,
args.disable_tta,
args.verbose,
args.disable_progress_bar,
args.chk,
args.save_probabilities,
args.continue_prediction,
args.npp,
args.nps,
args.prev_stage_predictions,
)


# Check if model exists. If not exist, download using HuggingFace
print(f"Using model folder: {model_folder}")
if not os.path.exists(model_folder):
# HF download model
print("DLMUSE model not found, downloading...")

from huggingface_hub import snapshot_download
local_src = Path(__file__).parent
snapshot_download(repo_id="nichart/DLMUSE", local_dir=local_src)

print("DLMUSE model has been successfully downloaded!")
else:
print("Loading the model...")

prepare_data_folder(des_folder)

assert (
args.part_id < args.num_parts
), "part_id < num_parts. Please see nnUNetv2_predict -h."

assert args.device in [
"cpu",
"cuda",
"mps",
], f"-device must be either cpu, mps or cuda. Got: {args.device}."

if args.device == "cpu":
import multiprocessing
# use half of the available threads in the system.
torch.set_num_threads(multiprocessing.cpu_count() // 2)
device = torch.device("cpu")
print("Running in CPU mode.")
elif args.device == "cuda":
# multithreading in torch doesn't help nnU-Net if run on GPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
device = torch.device("cuda")
print("Running in CUDA mode.")
else:
device = torch.device("mps")
print("Running in MPS mode.")

# exports for nnunetv2 purposes
os.environ["nnUNet_raw"] = "/nnunet_raw/"
os.environ["nnUNet_preprocessed"] = "/nnunet_preprocessed"
os.environ["nnUNet_results"] = (
"/nnunet_results" # where model will be located (fetched from HF)
)

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# Initialize nnUnetPredictor
predictor = nnUNetPredictor(
tile_step_size=args.step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
perform_everything_on_device=True,
device=device,
verbose=args.verbose,
verbose_preprocessing=args.verbose,
allow_tqdm=not args.disable_progress_bar,
)

# Retrieve the model and it's weight
predictor.initialize_from_trained_model_folder(
model_folder, args.f, checkpoint_name=args.chk
)

# Final prediction
predictor.predict_from_files(
des_folder,
args.o,
save_probabilities=args.save_probabilities,
overwrite=not args.continue_prediction,
num_processes_preprocessing=args.npp,
num_processes_segmentation_export=args.nps,
folder_with_segs_from_prev_stage=args.prev_stage_predictions,
num_parts=args.num_parts,
part_id=args.part_id,
)

# After prediction, convert the image name back to original
files_folder = args.o

for filename in os.listdir(files_folder):
if filename.endswith(".nii.gz"):
original_name = rename_back_dict[filename]
os.rename(
os.path.join(files_folder, filename),
os.path.join(files_folder, original_name),
)
# Remove the (temporary) des_folder directory
if os.path.exists(des_folder):
shutil.rmtree(des_folder)

print("Inference Process Done!")


if __name__ == "__main__":
main()
Loading

0 comments on commit 0306098

Please sign in to comment.