diff --git a/NiChart_DLMUSE/__main__.py b/NiChart_DLMUSE/__main__.py index 026353a..663232a 100644 --- a/NiChart_DLMUSE/__main__.py +++ b/NiChart_DLMUSE/__main__.py @@ -11,6 +11,7 @@ import threading from .dlmuse_pipeline import run_pipeline + from .utils import ( collect_T1, merge_bids_output_data, @@ -142,95 +143,25 @@ def main() -> None: device = args.device dlicv_extra_args = args.dlicv_args dlmuse_extra_args = args.dlmuse_args + clear_cache = args.clear_cache + bids = args.bids + cores = args.cores print() print("Arguments:") print(args) print() - if not os.path.isdir(out_dir): - print(f"Can't find {out_dir}, creating it...") - # os.system(f"mkdir {out_dir}") - os.mkdir(out_dir) - - elif len(os.listdir(out_dir)) != 0: - print(f"Emptying output folder: {out_dir}...") - for root, dirs, files in os.walk(out_dir): - for f in files: - os.unlink(os.path.join(root, f)) - for d in dirs: - shutil.rmtree(os.path.join(root, d)) - - if args.clear_cache: - os.system("DLICV -i ./dummy -o ./dummy --clear_cache") - os.system("DLMUSE -i ./dummy -o ./dummy --clear_cache") - - working_dir = os.path.join(os.path.abspath(out_dir)) - - # Run pipeline - if args.bids is True: - if int(args.cores) > 1: - collect_T1(in_dir, out_dir) - - no_threads = int(args.cores) - subfolders = split_data("raw_temp_T1", no_threads) - threads = [] - for i in range(len(subfolders)): - curr_out_dir = out_dir + f"/split_{i}" - curr_thread = threading.Thread( - target=run_pipeline, - args=( - subfolders[i], - curr_out_dir, - device, - dlmuse_extra_args, - dlicv_extra_args, - i, - ), - ) - curr_thread.start() - threads.append(curr_thread) - - for t in threads: - t.join() - - merge_bids_output_data(working_dir) - remove_subfolders("raw_temp_T1") - remove_subfolders(out_dir) - else: # No core parallelization - run_pipeline(in_dir, out_dir, device, dlmuse_extra_args, dlicv_extra_args) - - else: # Non-BIDS - if int(args.cores) > 1: - no_threads = int(args.cores) - subfolders = split_data(in_dir, no_threads) - - threads = [] - for i in range(len(subfolders)): - curr_out_dir = out_dir + f"/split_{i}" - curr_thread = threading.Thread( - target=run_pipeline, - args=( - subfolders[i], - curr_out_dir, - device, - dlmuse_extra_args, - dlicv_extra_args, - i, - ), - ) - curr_thread.start() - threads.append(curr_thread) - - for t in threads: - t.join() - - merge_output_data(out_dir) - remove_subfolders(in_dir) - remove_subfolders(out_dir) - else: # No core parallelization - run_pipeline(in_dir, out_dir, device, dlmuse_extra_args, dlicv_extra_args) - + run_pipeline( + in_dir, + out_dir, + device, + dlicv_extra_args, + dlmuse_extra_args, + clear_cache, + bids, + cores + ) if __name__ == "__main__": main() diff --git a/NiChart_DLMUSE/dlmuse_pipeline.py b/NiChart_DLMUSE/dlmuse_pipeline.py index 03f35c9..8b29268 100644 --- a/NiChart_DLMUSE/dlmuse_pipeline.py +++ b/NiChart_DLMUSE/dlmuse_pipeline.py @@ -1,6 +1,8 @@ import logging import os from typing import Any +import shutil +import threading import pkg_resources # type: ignore @@ -9,7 +11,14 @@ from .RelabelROI import apply_relabel_rois from .ReorientImage import apply_reorient_img, apply_reorient_to_init from .SegmentImage import run_dlicv, run_dlmuse -from .utils import make_img_list +from .utils import ( + make_img_list, + collect_T1, + merge_bids_output_data, + merge_output_data, + remove_subfolders, + split_data, +) # Config vars SUFF_LPS = "_LPS.nii.gz" @@ -35,8 +44,120 @@ logger = logging.getLogger(__name__) logging.basicConfig(filename="pipeline.log", encoding="utf-8", level=logging.DEBUG) - def run_pipeline( + in_dir: str, + out_dir: str, + device: str, + dlicv_extra_args: str, + dlmuse_extra_args: str, + clear_cache: bool, + bids: bool, + cores: str +): + """ + NiChart pipeline + + :param in_dir: The input directory + :type in_dir: str + :param out_dir: The output directory + :type out_dir: str + :type device: cpu/cuda/mps + :param device: str + :param dlicv_extra_args: Extra arguments for DLICV API + :type dlicv_extra_args: str + :param dlmuse_extra_args: Extra arguments for DLMUSE API + :type dlmuse_extra_args: str + :param clear_cache: True if cache should be cleared + :type clear_cache: bool + :param bids: True if your input is a bids type folder + :type bids: bool + :param cores: The number of cores(default is 4) + :type cores: str + + :rtype: None + """ + if not os.path.isdir(out_dir): + print(f"Can't find {out_dir}, creating it...") + # os.system(f"mkdir {out_dir}") + os.mkdir(out_dir) + + elif len(os.listdir(out_dir)) != 0: + print(f"Emptying output folder: {out_dir}...") + for root, dirs, files in os.walk(out_dir): + for f in files: + os.unlink(os.path.join(root, f)) + for d in dirs: + shutil.rmtree(os.path.join(root, d)) + if clear_cache: + os.system("DLICV -i ./dummy -o ./dummy --clear_cache") + os.system("DLMUSE -i ./dummy -o ./dummy --clear_cache") + + working_dir = os.path.join(os.path.abspath(out_dir)) + + if bids is True: + if int(cores) > 1: + collect_T1(in_dir, out_dir) + + no_threads = int(cores) + subfolders = split_data("raw_temp_T1", no_threads) + threads = [] + for i in range(len(subfolders)): + curr_out_dir = out_dir + f"/split_{i}" + curr_thread = threading.Thread( + target=run_thread, + args=( + subfolders[i], + curr_out_dir, + device, + dlmuse_extra_args, + dlicv_extra_args, + i, + ), + ) + curr_thread.start() + threads.append(curr_thread) + + for t in threads: + t.join() + + merge_bids_output_data(working_dir) + remove_subfolders("raw_temp_T1") + remove_subfolders(out_dir) + else: # No core parallelization + run_thread(in_dir, out_dir, device, dlmuse_extra_args, dlicv_extra_args) + + else: # Non-BIDS + if int(cores) > 1: + no_threads = int(cores) + subfolders = split_data(in_dir, no_threads) + + threads = [] + for i in range(len(subfolders)): + curr_out_dir = out_dir + f"/split_{i}" + curr_thread = threading.Thread( + target=run_thread, + args=( + subfolders[i], + curr_out_dir, + device, + dlmuse_extra_args, + dlicv_extra_args, + i, + ), + ) + curr_thread.start() + threads.append(curr_thread) + + for t in threads: + t.join() + + merge_output_data(out_dir) + remove_subfolders(in_dir) + remove_subfolders(out_dir) + else: # No core parallelization + run_thread(in_dir, out_dir, device, dlmuse_extra_args, dlicv_extra_args) + +def run_thread( in_data: str, out_dir: str, device: str, @@ -46,7 +167,7 @@ def run_pipeline( progress_bar: Any = None, ) -> None: """ - NiChart pipeline + Run a thread of the pipeline :param in_data: the input directory :type in_data: str