Skip to content

Commit

Permalink
Changed code organization | Less in main, more on pipeline.py
Browse files Browse the repository at this point in the history
  • Loading branch information
spirosmaggioros committed Dec 11, 2024
1 parent 9a288ec commit b08c40f
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 86 deletions.
97 changes: 14 additions & 83 deletions NiChart_DLMUSE/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import threading

from .dlmuse_pipeline import run_pipeline

from .utils import (
collect_T1,
merge_bids_output_data,
Expand Down Expand Up @@ -142,95 +143,25 @@ def main() -> None:
device = args.device
dlicv_extra_args = args.dlicv_args
dlmuse_extra_args = args.dlmuse_args
clear_cache = args.clear_cache
bids = args.bids
cores = args.cores

print()
print("Arguments:")
print(args)
print()

if not os.path.isdir(out_dir):
print(f"Can't find {out_dir}, creating it...")
# os.system(f"mkdir {out_dir}")
os.mkdir(out_dir)

elif len(os.listdir(out_dir)) != 0:
print(f"Emptying output folder: {out_dir}...")
for root, dirs, files in os.walk(out_dir):
for f in files:
os.unlink(os.path.join(root, f))
for d in dirs:
shutil.rmtree(os.path.join(root, d))

if args.clear_cache:
os.system("DLICV -i ./dummy -o ./dummy --clear_cache")
os.system("DLMUSE -i ./dummy -o ./dummy --clear_cache")

working_dir = os.path.join(os.path.abspath(out_dir))

# Run pipeline
if args.bids is True:
if int(args.cores) > 1:
collect_T1(in_dir, out_dir)

no_threads = int(args.cores)
subfolders = split_data("raw_temp_T1", no_threads)
threads = []
for i in range(len(subfolders)):
curr_out_dir = out_dir + f"/split_{i}"
curr_thread = threading.Thread(
target=run_pipeline,
args=(
subfolders[i],
curr_out_dir,
device,
dlmuse_extra_args,
dlicv_extra_args,
i,
),
)
curr_thread.start()
threads.append(curr_thread)

for t in threads:
t.join()

merge_bids_output_data(working_dir)
remove_subfolders("raw_temp_T1")
remove_subfolders(out_dir)
else: # No core parallelization
run_pipeline(in_dir, out_dir, device, dlmuse_extra_args, dlicv_extra_args)

else: # Non-BIDS
if int(args.cores) > 1:
no_threads = int(args.cores)
subfolders = split_data(in_dir, no_threads)

threads = []
for i in range(len(subfolders)):
curr_out_dir = out_dir + f"/split_{i}"
curr_thread = threading.Thread(
target=run_pipeline,
args=(
subfolders[i],
curr_out_dir,
device,
dlmuse_extra_args,
dlicv_extra_args,
i,
),
)
curr_thread.start()
threads.append(curr_thread)

for t in threads:
t.join()

merge_output_data(out_dir)
remove_subfolders(in_dir)
remove_subfolders(out_dir)
else: # No core parallelization
run_pipeline(in_dir, out_dir, device, dlmuse_extra_args, dlicv_extra_args)

run_pipeline(
in_dir,
out_dir,
device,
dlicv_extra_args,
dlmuse_extra_args,
clear_cache,
bids,
cores
)

if __name__ == "__main__":
main()
127 changes: 124 additions & 3 deletions NiChart_DLMUSE/dlmuse_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import os
from typing import Any
import shutil
import threading

import pkg_resources # type: ignore

Expand All @@ -9,7 +11,14 @@
from .RelabelROI import apply_relabel_rois
from .ReorientImage import apply_reorient_img, apply_reorient_to_init
from .SegmentImage import run_dlicv, run_dlmuse
from .utils import make_img_list
from .utils import (
make_img_list,
collect_T1,
merge_bids_output_data,
merge_output_data,
remove_subfolders,
split_data,
)

# Config vars
SUFF_LPS = "_LPS.nii.gz"
Expand All @@ -35,8 +44,120 @@
logger = logging.getLogger(__name__)
logging.basicConfig(filename="pipeline.log", encoding="utf-8", level=logging.DEBUG)


def run_pipeline(
in_dir: str,
out_dir: str,
device: str,
dlicv_extra_args: str,
dlmuse_extra_args: str,
clear_cache: bool,
bids: bool,
cores: str
):
"""
NiChart pipeline
:param in_dir: The input directory
:type in_dir: str
:param out_dir: The output directory
:type out_dir: str
:type device: cpu/cuda/mps
:param device: str
:param dlicv_extra_args: Extra arguments for DLICV API
:type dlicv_extra_args: str
:param dlmuse_extra_args: Extra arguments for DLMUSE API
:type dlmuse_extra_args: str
:param clear_cache: True if cache should be cleared
:type clear_cache: bool
:param bids: True if your input is a bids type folder
:type bids: bool
:param cores: The number of cores(default is 4)
:type cores: str
:rtype: None
"""
if not os.path.isdir(out_dir):
print(f"Can't find {out_dir}, creating it...")
# os.system(f"mkdir {out_dir}")
os.mkdir(out_dir)

elif len(os.listdir(out_dir)) != 0:
print(f"Emptying output folder: {out_dir}...")
for root, dirs, files in os.walk(out_dir):
for f in files:
os.unlink(os.path.join(root, f))
for d in dirs:
shutil.rmtree(os.path.join(root, d))
if clear_cache:
os.system("DLICV -i ./dummy -o ./dummy --clear_cache")
os.system("DLMUSE -i ./dummy -o ./dummy --clear_cache")

working_dir = os.path.join(os.path.abspath(out_dir))

if bids is True:
if int(cores) > 1:
collect_T1(in_dir, out_dir)

no_threads = int(cores)
subfolders = split_data("raw_temp_T1", no_threads)
threads = []
for i in range(len(subfolders)):
curr_out_dir = out_dir + f"/split_{i}"
curr_thread = threading.Thread(
target=run_thread,
args=(
subfolders[i],
curr_out_dir,
device,
dlmuse_extra_args,
dlicv_extra_args,
i,
),
)
curr_thread.start()
threads.append(curr_thread)

for t in threads:
t.join()

merge_bids_output_data(working_dir)
remove_subfolders("raw_temp_T1")
remove_subfolders(out_dir)
else: # No core parallelization
run_thread(in_dir, out_dir, device, dlmuse_extra_args, dlicv_extra_args)

else: # Non-BIDS
if int(cores) > 1:
no_threads = int(cores)
subfolders = split_data(in_dir, no_threads)

threads = []
for i in range(len(subfolders)):
curr_out_dir = out_dir + f"/split_{i}"
curr_thread = threading.Thread(
target=run_thread,
args=(
subfolders[i],
curr_out_dir,
device,
dlmuse_extra_args,
dlicv_extra_args,
i,
),
)
curr_thread.start()
threads.append(curr_thread)

for t in threads:
t.join()

merge_output_data(out_dir)
remove_subfolders(in_dir)
remove_subfolders(out_dir)
else: # No core parallelization
run_thread(in_dir, out_dir, device, dlmuse_extra_args, dlicv_extra_args)

def run_thread(
in_data: str,
out_dir: str,
device: str,
Expand All @@ -46,7 +167,7 @@ def run_pipeline(
progress_bar: Any = None,
) -> None:
"""
NiChart pipeline
Run a thread of the pipeline
:param in_data: the input directory
:type in_data: str
Expand Down

0 comments on commit b08c40f

Please sign in to comment.