Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

24.1.x #182

Merged
merged 5 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 107 additions & 9 deletions deepprep/nextflow/bin/bold_apply_transform_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@
import numpy as np
from pathlib import Path
from scipy import ndimage as ndi
from scipy.sparse import hstack as sparse_hstack
import argparse

from nitransforms.io import get_linear_factory
import templateflow.api as tflow
from multiprocessing import Pool
from bids import BIDSLayout
import json

from sdcflows.utils.tools import ensure_positive_cosines
from sdcflows.transform import grid_bspline_weights
import nitransforms as nt



def get_preproc_file(subject_id, bids_preproc, bold_orig_file, update_entities):
Expand All @@ -32,12 +39,12 @@ def get_preproc_file(subject_id, bids_preproc, bold_orig_file, update_entities):
return Path(bold_t1w_file)


def affine_to_3x3(itk):
matrix = itk[:3, :3]
translation = itk[:3, -1:]
def affine_to_3x3(mtx):
matrix = mtx[:3, :3]
translation = mtx[:3, -1:]
return matrix, translation

def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path):
def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path, bold_sdc, pe_info, vsm):
# bold_orig = nib.load(bold_file)
bold_orig_values = bold_orig.slicer[..., frame:frame+1].get_fdata()[..., 0]

Expand All @@ -47,6 +54,9 @@ def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b,
warped_mesh_frame = ras2vox_A @ warped_mesh_frame + ras2vox_b
warped_mesh_frame = warped_mesh_frame.reshape(3, *fixed.shape)

if bold_sdc:
warped_mesh_frame[pe_info[0][0], ...] += vsm

# interp values
output = np.zeros(
list(fixed.shape),
Expand All @@ -62,6 +72,11 @@ def apply_hmc_pool(frame, warped_mesh_frame, matrix_frame, ras2vox_A, ras2vox_b,
cval=0.0,
prefilter=True,
)

if bold_sdc:
# jacobian
result *= 1 + np.gradient(vsm, axis=pe_info[0][0])

result = result[..., np.newaxis]
nib.save(nib.Nifti1Image(result, affine=fixed.affine, header=bold_orig_header),
f'{transform_save_path}/t{str(frame).zfill(5)}.nii.gz')
Expand All @@ -73,12 +88,45 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json):
cmd = f'mri_concat --i {transform_save_path}/* --o {output_path}'
os.system(cmd)

# copy the first frame as boldref
shutil.copy(in_files[0], boldref_path)

# generate .json, it is consistent with T1w.json
boldref_json_path = str(output_path).replace('.nii.gz', '.json')
shutil.copy(t1_json, boldref_json_path)
try:
# copy the first frame as boldref
shutil.copy(in_files[0], boldref_path)
shutil.copy(t1_json, boldref_json_path)
except:
pass


def apply_fieldmap2std(in_coeffs, target_ref_file, fmap_ref_file, transforms):
coefficients = [nib.load(in_coeffs)]
target = nib.load(target_ref_file)
fmapref = nib.load(fmap_ref_file)

warp_matrix = nib.load(transforms[0]).get_fdata()
warp_affine = nib.load(transforms[0]).affine
transform_chain = nt.TransformChain(nt.DenseFieldTransform(nib.Nifti1Image(warp_matrix, warp_affine)))
xfm = nt.linear.load(transforms[1])
transform_chain += xfm
xfm = nt.linear.load(transforms[2])
transform_chain += xfm
reference, _ = ensure_positive_cosines(fmapref)
colmat = sparse_hstack(
[grid_bspline_weights(reference, level) for level in coefficients]
).tocsr()
coefficients = np.hstack(
[level.get_fdata(dtype='float32').reshape(-1) for level in coefficients]
)
fmap_img = nib.Nifti1Image(
np.reshape(colmat @ coefficients, reference.shape[:3]),
reference.affine,
)
fmap_img = transform_chain.apply(fmap_img, reference=target)
fmap_img.header.set_intent('estimate', name='fieldmap Hz')
fmap_img.header.set_data_dtype('float32')
fmap_img.header['cal_max'] = max((abs(fmap_img.dataobj.min()), fmap_img.dataobj.max()))
fmap_img.header['cal_min'] = -fmap_img.header['cal_max']
return fmap_img


if __name__ == '__main__':
Expand Down Expand Up @@ -106,6 +154,8 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json):
parser.add_argument("--template_space", required=True)
parser.add_argument("--template_resolution", required=True)
parser.add_argument("--nonlinear_file", required=True)
parser.add_argument("--bold_sdc", required=True)
parser.add_argument("--task_id", required=True)
parser.add_argument("--reference", required=False, default=None)
parser.add_argument("--moving", required=False, default=None)
args = parser.parse_args()
Expand Down Expand Up @@ -137,6 +187,33 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json):
output_path = Path(coreg_xfm.parent) / f'{args.bold_id}_space-{args.template_space}_res-{args.template_resolution}_desc-preproc_bold.nii.gz'
boldref_path = Path(coreg_xfm.parent) / f'{args.bold_id}_space-{args.template_space}_res-{args.template_resolution}_boldref.nii.gz'

# load fieldmap info
if args.bold_sdc:
update_entities = {'suffix': 'bold', 'extension': '.json'}
bold_json = get_preproc_file(args.subject_id, args.bids_dir, bold_file, update_entities)
with open(str(bold_json)) as f:
bold_info = json.load(f)
pe_dir = bold_info["PhaseEncodingDirection"]
ro_time = bold_info["TotalReadoutTime"]

fmap_base_dir = Path(args.work_dir) / 'bold_preprocess' /f'{args.subject_id}_wf' / f'{args.task_id}_wf'
fieldmap_id_txt_path = fmap_base_dir / 'fieldmap_id.txt'
with open(str(fieldmap_id_txt_path)) as f:
fieldmap_id_info = json.load(f)
fieldmap_id = fieldmap_id_info.get(bold_file)

coeff_dir = fmap_base_dir.parent / 'fmap_preproc_wf' / f'wf_{fieldmap_id}' / 'fix_coeff'
in_coeff = sorted(coeff_dir.glob("*_fieldcoef_fixed.nii.gz"))[0]


fmap_ref_file = str(fmap_base_dir.parent / 'fmap_preproc_wf' / f'wf_{fieldmap_id}' / 'brainextraction_wf' / 'clipper_post' / 'clipped.nii.gz')

# get the coreg.xfm
update_entities = {'mode': 'image', 'suffix': 'xfm', 'extension': '.txt'}
fieldmap_xfm = get_preproc_file(args.subject_id, args.bold_preprocess_dir, bold_file, update_entities)
transforms = [nonlinear_file, coreg_xfm, fieldmap_xfm]
fieldmap = apply_fieldmap2std(in_coeff, fixed_file, fmap_ref_file, transforms)

# Load the fixed file
fixed = nib.load(fixed_file)
vox2ras = fixed.affine
Expand Down Expand Up @@ -181,9 +258,30 @@ def concat_frames(transform_save_path, output_path, boldref_path, t1_json):
ras2vox_bold = np.linalg.inv(bold_orig.affine)
ras2vox_A, ras2vox_b = affine_to_3x3(ras2vox_bold)

# apply fieldmap if available
if args.bold_sdc:
nvols = bold_orig.shape[3] if bold_orig.ndim > 3 else 1

if pe_dir and ro_time:
pe_axis = "ijk".index(pe_dir[0])
pe_flip = pe_dir.endswith("-")

# Nitransforms displacements are positive
source, axcodes = ensure_positive_cosines(bold_orig)
axis_flip = axcodes[pe_axis] in "LPI"

pe_info = [(pe_axis, -ro_time if (axis_flip ^ pe_flip) else ro_time)] * nvols

fmap_hz = fieldmap.get_fdata(dtype='f4')
vsm = fmap_hz * pe_info[0][1]

else:
pe_info = None
vsm = None

args_apply_hmc = []
for i in range(matrix.shape[0]):
args_apply_hmc.append([int(i), warped_mesh, matrix[i], ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path])
args_apply_hmc.append([int(i), warped_mesh, matrix[i], ras2vox_A, ras2vox_b, bold_orig, fixed, bold_orig_header, transform_save_path, args.bold_sdc, pe_info, vsm])
pool = Pool(10)
pool.starmap(apply_hmc_pool, args_apply_hmc)
pool.close()
Expand Down
5 changes: 5 additions & 0 deletions deepprep/nextflow/bin/bold_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import shutil
from pathlib import Path
import json

from fmriprep.workflows.fieldmap import init_single_subject_fieldmap_wf
from fmriprep.workflows.bold.base import init_bold_wf
Expand Down Expand Up @@ -235,6 +236,10 @@ def get_bold_func_path(subject_id, bids_preproc, bold_orig_file):
fmap_base_dir = Path(config.execution.work_dir) / f'{subject_id}_wf'
base_dir = fmap_base_dir / f'{args.task_id[0]}_wf'
base_dir.mkdir(parents=True, exist_ok=True)
# save fieldmap id
fieldmap_id_txt = base_dir / 'fieldmap_id.txt'
with open(fieldmap_id_txt, 'w') as f:
json.dump(estimator_map, f)
workflow.connect([
(inputnode, bold_wf, [
("t1w_preproc", "inputnode.t1w_preproc"),
Expand Down
8 changes: 6 additions & 2 deletions deepprep/nextflow/deepprep.nf
Original file line number Diff line number Diff line change
Expand Up @@ -2102,12 +2102,14 @@ process bold_transform_chain {
tuple(val(subject_id), val(bold_id), val(trans), path(subject_boldfile_txt_bold))
val(template_space)
val(template_resolution)
val(bold_sdc)

output:
tuple(val(subject_id), val(bold_id), val("${bold_preprocess_path}/${subject_id}/func/${bold_id}_space-${template_space}_res-${template_resolution}_desc-preproc_bold.nii.gz"))
tuple(val(subject_id), val(bold_id), val("${bold_preprocess_path}/${subject_id}/func/${bold_id}_space-${template_space}_res-${template_resolution}_boldref.nii.gz"))

script:
task_id = bold_id.split('task-')[1].split('_')[0]
script_py = "bold_apply_transform_chain.py"

"""
Expand All @@ -2120,7 +2122,9 @@ process bold_transform_chain {
--subject_boldfile_txt_bold ${subject_boldfile_txt_bold} \
--template_space ${template_space} \
--template_resolution ${template_resolution} \
--nonlinear_file ${trans}
--nonlinear_file ${trans} \
--bold_sdc ${bold_sdc} \
--task_id ${task_id}
"""
}

Expand Down Expand Up @@ -2976,7 +2980,7 @@ workflow bold_wf {
(t1_norigid_nii, norm_norigid_nii, trans) = bold_synthmorph_joint(subjects_dir, bold_preprocess_path, synthmorph_home, bold_synthmorph_joint_input, synthmorph_model_path, template_space, device, gpu_lock)

bold_transform_chain_input = subject_id_boldfile_id.groupTuple(sort: true).join(trans, by:[0]).transpose().join(subject_boldfile_txt_bold_pre_process, by: [0, 1])
(preproc_bold, boldref_file) = bold_transform_chain(bids_dir, bold_preprocess_path, work_dir, bold_transform_chain_input, template_space, template_resolution)
(preproc_bold, boldref_file) = bold_transform_chain(bids_dir, bold_preprocess_path, work_dir, bold_transform_chain_input, template_space, template_resolution, bold_sdc)
}

do_bold_qc = 'TRUE'
Expand Down