Skip to content

Commit

Permalink
24.1.x (#182)
Browse files Browse the repository at this point in the history
* CHG: add fielpmap
  • Loading branch information
Ireneyou33 authored Dec 2, 2024
1 parent 17b07f1 commit ff9a707
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 11 deletions.
116 changes: 107 additions & 9 deletions deepprep/nextflow/bin/bold_apply_transform_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@
import numpy as np
from pathlib import Path
from scipy import ndimage as ndi
from scipy.sparse import hstack as sparse_hstack
import argparse

from nitransforms.io import get_linear_factory
import templateflow.api as tflow
from multiprocessing import Pool
from bids import BIDSLayout
import json

from sdcflows.utils.tools import ensure_positive_cosines
from sdcflows.transform import grid_bspline_weights
import nitransforms as nt



def get_preproc_file(subject_id, bids_preproc, bold_orig_file, update_entities):
Expand All @@ -32,12 +39,12 @@ def get_preproc_file(subject_id, bids_preproc, bold_orig_file, update_entities):
return Path(bold_t1w_file)


def affine_to_3x3(itk):
matrix = itk[:3, :3]
translation = itk[:3, -1:]
def affine_to_3x3(mtx):
matrix = mtx[:3, :3]
translation = mtx[:3, -1:]
return matrix, translation

def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path):
def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path, bold_sdc, pe_info, vsm):
# bold_orig = nib.load(bold_file)
bold_orig_values = bold_orig.slicer[..., frame:frame+1].get_fdata()[..., 0]

Expand All @@ -47,6 +54,9 @@ def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b,
warped_mesh_frame = ras2vox_A @ warped_mesh_frame + ras2vox_b
warped_mesh_frame = warped_mesh_frame.reshape(3, *fixed.shape)

if bold_sdc:
warped_mesh_frame[pe_info[0][0], ...] += vsm

# interp values
output = np.zeros(
list(fixed.shape),
Expand All @@ -62,6 +72,11 @@ def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b,
cval=0.0,
prefilter=True,
)

if bold_sdc:
# jacobian
result *= 1 + np.gradient(vsm, axis=pe_info[0][0])

result = result[..., np.newaxis]
nib.save(nib.Nifti1Image(result, affine=fixed.affine, header=bold_orig_header),
f'{transform_save_path}/t{str(frame).zfill(5)}.nii.gz')
Expand All @@ -73,12 +88,45 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json):
cmd = f'mri_concat --i {transform_save_path}/* --o {output_path}'
os.system(cmd)

# copy the first frame as boldref
shutil.copy(in_files[0], boldref_path)

# generate .json, it is consistent with T1w.json
boldref_json_path = str(output_path).replace('.nii.gz', '.json')
shutil.copy(t1_json, boldref_json_path)
try:
# copy the first frame as boldref
shutil.copy(in_files[0], boldref_path)
shutil.copy(t1_json, boldref_json_path)
except:
pass


def apply_fieldmap2std(in_coeffs, target_ref_file, fmap_ref_file, transforms):
coefficients = [nib.load(in_coeffs)]
target = nib.load(target_ref_file)
fmapref = nib.load(fmap_ref_file)

warp_matrix = nib.load(transforms[0]).get_fdata()
warp_affine = nib.load(transforms[0]).affine
transform_chain = nt.TransformChain(nt.DenseFieldTransform(nib.Nifti1Image(warp_matrix, warp_affine)))
xfm = nt.linear.load(transforms[1])
transform_chain += xfm
xfm = nt.linear.load(transforms[2])
transform_chain += xfm
reference, _ = ensure_positive_cosines(fmapref)
colmat = sparse_hstack(
[grid_bspline_weights(reference, level) for level in coefficients]
).tocsr()
coefficients = np.hstack(
[level.get_fdata(dtype='float32').reshape(-1) for level in coefficients]
)
fmap_img = nib.Nifti1Image(
np.reshape(colmat @ coefficients, reference.shape[:3]),
reference.affine,
)
fmap_img = transform_chain.apply(fmap_img, reference=target)
fmap_img.header.set_intent('estimate', name='fieldmap Hz')
fmap_img.header.set_data_dtype('float32')
fmap_img.header['cal_max'] = max((abs(fmap_img.dataobj.min()), fmap_img.dataobj.max()))
fmap_img.header['cal_min'] = -fmap_img.header['cal_max']
return fmap_img


if __name__ == '__main__':
Expand Down Expand Up @@ -106,6 +154,8 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json):
parser.add_argument("--template_space", required=True)
parser.add_argument("--template_resolution", required=True)
parser.add_argument("--nonlinear_file", required=True)
parser.add_argument("--bold_sdc", required=True)
parser.add_argument("--task_id", required=True)
parser.add_argument("--reference", required=False, default=None)
parser.add_argument("--moving", required=False, default=None)
args = parser.parse_args()
Expand Down Expand Up @@ -137,6 +187,33 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json):
output_path = Path(coreg_xfm.parent) / f'{args.bold_id}_space-{args.template_space}_res-{args.template_resolution}_desc-preproc_bold.nii.gz'
boldref_path = Path(coreg_xfm.parent) / f'{args.bold_id}_space-{args.template_space}_res-{args.template_resolution}_boldref.nii.gz'

# load fieldmap info
if args.bold_sdc:
update_entities = {'suffix': 'bold', 'extension': '.json'}
bold_json = get_preproc_file(args.subject_id, args.bids_dir, bold_file, update_entities)
with open(str(bold_json)) as f:
bold_info = json.load(f)
pe_dir = bold_info["PhaseEncodingDirection"]
ro_time = bold_info["TotalReadoutTime"]

fmap_base_dir = Path(args.work_dir) / 'bold_preprocess' /f'{args.subject_id}_wf' / f'{args.task_id}_wf'
fieldmap_id_txt_path = fmap_base_dir / 'fieldmap_id.txt'
with open(str(fieldmap_id_txt_path)) as f:
fieldmap_id_info = json.load(f)
fieldmap_id = fieldmap_id_info.get(bold_file)

coeff_dir = fmap_base_dir.parent / 'fmap_preproc_wf' / f'wf_{fieldmap_id}' / 'fix_coeff'
in_coeff = sorted(coeff_dir.glob("*_fieldcoef_fixed.nii.gz"))[0]


fmap_ref_file = str(fmap_base_dir.parent / 'fmap_preproc_wf' / f'wf_{fieldmap_id}' / 'brainextraction_wf' / 'clipper_post' / 'clipped.nii.gz')

# get the coreg.xfm
update_entities = {'mode': 'image', 'suffix': 'xfm', 'extension': '.txt'}
fieldmap_xfm = get_preproc_file(args.subject_id, args.bold_preprocess_dir, bold_file, update_entities)
transforms = [nonlinear_file, coreg_xfm, fieldmap_xfm]
fieldmap = apply_fieldmap2std(in_coeff, fixed_file, fmap_ref_file, transforms)

# Load the fixed file
fixed = nib.load(fixed_file)
vox2ras = fixed.affine
Expand Down Expand Up @@ -181,9 +258,30 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json):
ras2vox_bold = np.linalg.inv(bold_orig.affine)
ras2vox_A, ras2vox_b = affine_to_3x3(ras2vox_bold)

# apply fieldmap if available
if args.bold_sdc:
nvols = bold_orig.shape[3] if bold_orig.ndim > 3 else 1

if pe_dir and ro_time:
pe_axis = "ijk".index(pe_dir[0])
pe_flip = pe_dir.endswith("-")

# Nitransforms displacements are positive
source, axcodes = ensure_positive_cosines(bold_orig)
axis_flip = axcodes[pe_axis] in "LPI"

pe_info = [(pe_axis, -ro_time if (axis_flip ^ pe_flip) else ro_time)] * nvols

fmap_hz = fieldmap.get_fdata(dtype='f4')
vsm = fmap_hz * pe_info[0][1]

else:
pe_info = None
vsm = None

args_apply_hmc = []
for i in range(matrix.shape[0]):
args_apply_hmc.append([int(i), warped_mesh, matrix[i], ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path])
args_apply_hmc.append([int(i), warped_mesh, matrix[i], ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path, args.bold_sdc, pe_info, vsm])
pool = Pool(10)
pool.starmap(apply_hmc_pool, args_apply_hmc)
pool.close()
Expand Down
5 changes: 5 additions & 0 deletions deepprep/nextflow/bin/bold_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import shutil
from pathlib import Path
import json

from fmriprep.workflows.fieldmap import init_single_subject_fieldmap_wf
from fmriprep.workflows.bold.base import init_bold_wf
Expand Down Expand Up @@ -235,6 +236,10 @@ def get_bold_func_path(subject_id, bids_preproc, bold_orig_file):
fmap_base_dir = Path(config.execution.work_dir) / f'{subject_id}_wf'
base_dir = fmap_base_dir / f'{args.task_id[0]}_wf'
base_dir.mkdir(parents=True, exist_ok=True)
# save fieldmap id
fieldmap_id_txt = base_dir / 'fieldmap_id.txt'
with open(fieldmap_id_txt, 'w') as f:
json.dump(estimator_map, f)
workflow.connect([
(inputnode, bold_wf, [
("t1w_preproc", "inputnode.t1w_preproc"),
Expand Down
8 changes: 6 additions & 2 deletions deepprep/nextflow/deepprep.nf
Original file line number Diff line number Diff line change
Expand Up @@ -2102,12 +2102,14 @@ process bold_transform_chain {
tuple(val(subject_id), val(bold_id), val(trans), path(subject_boldfile_txt_bold))
val(template_space)
val(template_resolution)
val(bold_sdc)

output:
tuple(val(subject_id), val(bold_id), val("${bold_preprocess_path}/${subject_id}/func/${bold_id}_space-${template_space}_res-${template_resolution}_desc-preproc_bold.nii.gz"))
tuple(val(subject_id), val(bold_id), val("${bold_preprocess_path}/${subject_id}/func/${bold_id}_space-${template_space}_res-${template_resolution}_boldref.nii.gz"))

script:
task_id = bold_id.split('task-')[1].split('_')[0]
script_py = "bold_apply_transform_chain.py"

"""
Expand All @@ -2120,7 +2122,9 @@ process bold_transform_chain {
--subject_boldfile_txt_bold ${subject_boldfile_txt_bold} \
--template_space ${template_space} \
--template_resolution ${template_resolution} \
--nonlinear_file ${trans}
--nonlinear_file ${trans} \
--bold_sdc ${bold_sdc} \
--task_id ${task_id}
"""
}

Expand Down Expand Up @@ -2976,7 +2980,7 @@ workflow bold_wf {
(t1_norigid_nii, norm_norigid_nii, trans) = bold_synthmorph_joint(subjects_dir, bold_preprocess_path, synthmorph_home, bold_synthmorph_joint_input, synthmorph_model_path, template_space, device, gpu_lock)

bold_transform_chain_input = subject_id_boldfile_id.groupTuple(sort: true).join(trans, by:[0]).transpose().join(subject_boldfile_txt_bold_pre_process, by: [0, 1])
(preproc_bold, boldref_file) = bold_transform_chain(bids_dir, bold_preprocess_path, work_dir, bold_transform_chain_input, template_space, template_resolution)
(preproc_bold, boldref_file) = bold_transform_chain(bids_dir, bold_preprocess_path, work_dir, bold_transform_chain_input, template_space, template_resolution, bold_sdc)
}

do_bold_qc = 'TRUE'
Expand Down

0 comments on commit ff9a707

Please sign in to comment.