diff --git a/deepprep/nextflow/bin/bold_apply_transform_chain.py b/deepprep/nextflow/bin/bold_apply_transform_chain.py index d752788..79a0de5 100755 --- a/deepprep/nextflow/bin/bold_apply_transform_chain.py +++ b/deepprep/nextflow/bin/bold_apply_transform_chain.py @@ -7,12 +7,19 @@ import numpy as np from pathlib import Path from scipy import ndimage as ndi +from scipy.sparse import hstack as sparse_hstack import argparse from nitransforms.io import get_linear_factory import templateflow.api as tflow from multiprocessing import Pool from bids import BIDSLayout +import json + +from sdcflows.utils.tools import ensure_positive_cosines +from sdcflows.transform import grid_bspline_weights +import nitransforms as nt + def get_preproc_file(subject_id, bids_preproc, bold_orig_file, update_entities): @@ -32,12 +39,12 @@ def get_preproc_file(subject_id, bids_preproc, bold_orig_file, update_entities): return Path(bold_t1w_file) -def affine_to_3x3(itk): - matrix = itk[:3, :3] - translation = itk[:3, -1:] +def affine_to_3x3(mtx): + matrix = mtx[:3, :3] + translation = mtx[:3, -1:] return matrix, translation -def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path): +def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path, bold_sdc, pe_info, vsm): # bold_orig = nib.load(bold_file) bold_orig_values = bold_orig.slicer[..., frame:frame+1].get_fdata()[..., 0] @@ -47,6 +54,9 @@ def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b, warped_mesh_frame = ras2vox_A @ warped_mesh_frame + ras2vox_b warped_mesh_frame = warped_mesh_frame.reshape(3, *fixed.shape) + if bold_sdc: + warped_mesh_frame[pe_info[0][0], ...] += vsm + # interp values output = np.zeros( list(fixed.shape), @@ -62,6 +72,11 @@ def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b, cval=0.0, prefilter=True, ) + + if bold_sdc: + # jacobian + result *= 1 + np.gradient(vsm, axis=pe_info[0][0]) + result = result[..., np.newaxis] nib.save(nib.Nifti1Image(result, affine=fixed.affine, header=bold_orig_header), f'{transform_save_path}/t{str(frame).zfill(5)}.nii.gz') @@ -73,12 +88,45 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json): cmd = f'mri_concat --i {transform_save_path}/* --o {output_path}' os.system(cmd) - # copy the first frame as boldref - shutil.copy(in_files[0], boldref_path) - # generate .json, it is consistent with T1w.json boldref_json_path = str(output_path).replace('.nii.gz', '.json') - shutil.copy(t1_json, boldref_json_path) + try: + # copy the first frame as boldref + shutil.copy(in_files[0], boldref_path) + shutil.copy(t1_json, boldref_json_path) + except: + pass + + +def apply_fieldmap2std(in_coeffs, target_ref_file, fmap_ref_file, transforms): + coefficients = [nib.load(in_coeffs)] + target = nib.load(target_ref_file) + fmapref = nib.load(fmap_ref_file) + + warp_matrix = nib.load(transforms[0]).get_fdata() + warp_affine = nib.load(transforms[0]).affine + transform_chain = nt.TransformChain(nt.DenseFieldTransform(nib.Nifti1Image(warp_matrix, warp_affine))) + xfm = nt.linear.load(transforms[1]) + transform_chain += xfm + xfm = nt.linear.load(transforms[2]) + transform_chain += xfm + reference, _ = ensure_positive_cosines(fmapref) + colmat = sparse_hstack( + [grid_bspline_weights(reference, level) for level in coefficients] + ).tocsr() + coefficients = np.hstack( + [level.get_fdata(dtype='float32').reshape(-1) for level in coefficients] + ) + fmap_img = nib.Nifti1Image( + np.reshape(colmat @ coefficients, reference.shape[:3]), + reference.affine, + ) + fmap_img = transform_chain.apply(fmap_img, reference=target) + fmap_img.header.set_intent('estimate', name='fieldmap Hz') + fmap_img.header.set_data_dtype('float32') + fmap_img.header['cal_max'] = max((abs(fmap_img.dataobj.min()), fmap_img.dataobj.max())) + fmap_img.header['cal_min'] = -fmap_img.header['cal_max'] + return fmap_img if __name__ == '__main__': @@ -106,6 +154,8 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json): parser.add_argument("--template_space", required=True) parser.add_argument("--template_resolution", required=True) parser.add_argument("--nonlinear_file", required=True) + parser.add_argument("--bold_sdc", required=True) + parser.add_argument("--task_id", required=True) parser.add_argument("--reference", required=False, default=None) parser.add_argument("--moving", required=False, default=None) args = parser.parse_args() @@ -137,6 +187,33 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json): output_path = Path(coreg_xfm.parent) / f'{args.bold_id}_space-{args.template_space}_res-{args.template_resolution}_desc-preproc_bold.nii.gz' boldref_path = Path(coreg_xfm.parent) / f'{args.bold_id}_space-{args.template_space}_res-{args.template_resolution}_boldref.nii.gz' + # load fieldmap info + if args.bold_sdc: + update_entities = {'suffix': 'bold', 'extension': '.json'} + bold_json = get_preproc_file(args.subject_id, args.bids_dir, bold_file, update_entities) + with open(str(bold_json)) as f: + bold_info = json.load(f) + pe_dir = bold_info["PhaseEncodingDirection"] + ro_time = bold_info["TotalReadoutTime"] + + fmap_base_dir = Path(args.work_dir) / 'bold_preprocess' /f'{args.subject_id}_wf' / f'{args.task_id}_wf' + fieldmap_id_txt_path = fmap_base_dir / 'fieldmap_id.txt' + with open(str(fieldmap_id_txt_path)) as f: + fieldmap_id_info = json.load(f) + fieldmap_id = fieldmap_id_info.get(bold_file) + + coeff_dir = fmap_base_dir.parent / 'fmap_preproc_wf' / f'wf_{fieldmap_id}' / 'fix_coeff' + in_coeff = sorted(coeff_dir.glob("*_fieldcoef_fixed.nii.gz"))[0] + + + fmap_ref_file = str(fmap_base_dir.parent / 'fmap_preproc_wf' / f'wf_{fieldmap_id}' / 'brainextraction_wf' / 'clipper_post' / 'clipped.nii.gz') + + # get the coreg.xfm + update_entities = {'mode': 'image', 'suffix': 'xfm', 'extension': '.txt'} + fieldmap_xfm = get_preproc_file(args.subject_id, args.bold_preprocess_dir, bold_file, update_entities) + transforms = [nonlinear_file, coreg_xfm, fieldmap_xfm] + fieldmap = apply_fieldmap2std(in_coeff, fixed_file, fmap_ref_file, transforms) + # Load the fixed file fixed = nib.load(fixed_file) vox2ras = fixed.affine @@ -181,9 +258,30 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json): ras2vox_bold = np.linalg.inv(bold_orig.affine) ras2vox_A, ras2vox_b = affine_to_3x3(ras2vox_bold) + # apply fieldmap if available + if args.bold_sdc: + nvols = bold_orig.shape[3] if bold_orig.ndim > 3 else 1 + + if pe_dir and ro_time: + pe_axis = "ijk".index(pe_dir[0]) + pe_flip = pe_dir.endswith("-") + + # Nitransforms displacements are positive + source, axcodes = ensure_positive_cosines(bold_orig) + axis_flip = axcodes[pe_axis] in "LPI" + + pe_info = [(pe_axis, -ro_time if (axis_flip ^ pe_flip) else ro_time)] * nvols + + fmap_hz = fieldmap.get_fdata(dtype='f4') + vsm = fmap_hz * pe_info[0][1] + + else: + pe_info = None + vsm = None + args_apply_hmc = [] for i in range(matrix.shape[0]): - args_apply_hmc.append([int(i), warped_mesh, matrix[i], ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path]) + args_apply_hmc.append([int(i), warped_mesh, matrix[i], ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path, args.bold_sdc, pe_info, vsm]) pool = Pool(10) pool.starmap(apply_hmc_pool, args_apply_hmc) pool.close() diff --git a/deepprep/nextflow/bin/bold_preprocess.py b/deepprep/nextflow/bin/bold_preprocess.py index 96b074b..21b6c4d 100755 --- a/deepprep/nextflow/bin/bold_preprocess.py +++ b/deepprep/nextflow/bin/bold_preprocess.py @@ -3,6 +3,7 @@ import os import shutil from pathlib import Path +import json from fmriprep.workflows.fieldmap import init_single_subject_fieldmap_wf from fmriprep.workflows.bold.base import init_bold_wf @@ -235,6 +236,10 @@ def get_bold_func_path(subject_id, bids_preproc, bold_orig_file): fmap_base_dir = Path(config.execution.work_dir) / f'{subject_id}_wf' base_dir = fmap_base_dir / f'{args.task_id[0]}_wf' base_dir.mkdir(parents=True, exist_ok=True) + # save fieldmap id + fieldmap_id_txt = base_dir / 'fieldmap_id.txt' + with open(fieldmap_id_txt, 'w') as f: + json.dump(estimator_map, f) workflow.connect([ (inputnode, bold_wf, [ ("t1w_preproc", "inputnode.t1w_preproc"), diff --git a/deepprep/nextflow/deepprep.nf b/deepprep/nextflow/deepprep.nf index deab987..66256b6 100644 --- a/deepprep/nextflow/deepprep.nf +++ b/deepprep/nextflow/deepprep.nf @@ -2102,12 +2102,14 @@ process bold_transform_chain { tuple(val(subject_id), val(bold_id), val(trans), path(subject_boldfile_txt_bold)) val(template_space) val(template_resolution) + val(bold_sdc) output: tuple(val(subject_id), val(bold_id), val("${bold_preprocess_path}/${subject_id}/func/${bold_id}_space-${template_space}_res-${template_resolution}_desc-preproc_bold.nii.gz")) tuple(val(subject_id), val(bold_id), val("${bold_preprocess_path}/${subject_id}/func/${bold_id}_space-${template_space}_res-${template_resolution}_boldref.nii.gz")) script: + task_id = bold_id.split('task-')[1].split('_')[0] script_py = "bold_apply_transform_chain.py" """ @@ -2120,7 +2122,9 @@ process bold_transform_chain { --subject_boldfile_txt_bold ${subject_boldfile_txt_bold} \ --template_space ${template_space} \ --template_resolution ${template_resolution} \ - --nonlinear_file ${trans} + --nonlinear_file ${trans} \ + --bold_sdc ${bold_sdc} \ + --task_id ${task_id} """ } @@ -2976,7 +2980,7 @@ workflow bold_wf { (t1_norigid_nii, norm_norigid_nii, trans) = bold_synthmorph_joint(subjects_dir, bold_preprocess_path, synthmorph_home, bold_synthmorph_joint_input, synthmorph_model_path, template_space, device, gpu_lock) bold_transform_chain_input = subject_id_boldfile_id.groupTuple(sort: true).join(trans, by:[0]).transpose().join(subject_boldfile_txt_bold_pre_process, by: [0, 1]) - (preproc_bold, boldref_file) = bold_transform_chain(bids_dir, bold_preprocess_path, work_dir, bold_transform_chain_input, template_space, template_resolution) + (preproc_bold, boldref_file) = bold_transform_chain(bids_dir, bold_preprocess_path, work_dir, bold_transform_chain_input, template_space, template_resolution, bold_sdc) } do_bold_qc = 'TRUE'