Skip to content

Commit

Permalink
ENH: --parallel option for b0_mean
Browse files Browse the repository at this point in the history
  • Loading branch information
allemangD committed Mar 26, 2024
1 parent 9b9fff2 commit f513c0e
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/abcd_microstructure_pipelines/masks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import itertools
import logging
import multiprocessing
import os
from pathlib import Path

Expand Down Expand Up @@ -32,7 +34,8 @@ def gen_b0_mean(dwi: Path, bval: Path, bvec: Path, b0_out: Path):
@click.option("--inputs", "-i", required=True, type=Path)
@click.option("--outputs", "-o", required=True, type=Path)
@click.option("--overwrite", is_flag=True)
def gen_masks(inputs: Path, outputs: Path, overwrite: bool):
@click.option("--parallel", "-j", is_flag=True)
def gen_masks(inputs: Path, outputs: Path, overwrite: bool, parallel: bool):
b0_tasks = []
hd_bet_input = []
hd_bet_output = []
Expand All @@ -57,12 +60,21 @@ def gen_masks(inputs: Path, outputs: Path, overwrite: bool):
hd_bet_input.append(str(b0_out))
hd_bet_output.append(str(mask_out))

logging.debug("Generate missing b0_mean")
for dwi, bval, bvec, mask_out in b0_tasks:
gen_b0_mean(dwi, bval, bvec, mask_out)
if parallel:
logging.debug("Generate b0_mean in parallel")
pool = multiprocessing.Pool()
starmap = pool.starmap
else:
logging.debug("Generate b0_mean sequentially")
starmap = itertools.starmap

logging.debug("Generate missing masks")
logging.debug("Generate %s b0_mean", len(b0_tasks))
for _ in starmap(gen_b0_mean, b0_tasks):
pass # just consume the iterator. maybe wrap in tqdm?

logging.debug("Loading HD_BET")
# don't import till now since it takes time to initialize.
import HD_BET.run

logging.debug("Generate %s masks", len(hd_bet_input))
HD_BET.run.run_hd_bet(hd_bet_input, hd_bet_output, overwrite=overwrite)

0 comments on commit f513c0e

Please sign in to comment.