diff --git a/nibabies/workflows/anatomical/base.py b/nibabies/workflows/anatomical/base.py index a8ab39cd..78a5f773 100644 --- a/nibabies/workflows/anatomical/base.py +++ b/nibabies/workflows/anatomical/base.py @@ -348,9 +348,11 @@ def init_infant_anat_wf( name='deriv_buffer', ) if derivatives: - wf.connect( - coregistration_wf, 'outputnode.t1w2t2w_xfm', coreg_deriv_wf, 'inputnode.t1w2t2w_xfm' - ) + wf.connect([ + (coregistration_wf, coreg_deriv_wf, [('outputnode.t1w2t2w_xfm', 'inputnode.t1w2t2w_xfm')]), + (t1w_preproc_wf, coreg_deriv_wf, [('outputnode.anat_preproc', 'inputnode.t1w_ref')]), + (t2w_preproc_wf, coreg_deriv_wf, [('outputnode.anat_preproc', 'inputnode.t2w_ref')]), + ]) # Derivative mask is present if derivatives.mask: diff --git a/nibabies/workflows/anatomical/registration.py b/nibabies/workflows/anatomical/registration.py index 064af0a5..f977cacb 100644 --- a/nibabies/workflows/anatomical/registration.py +++ b/nibabies/workflows/anatomical/registration.py @@ -240,11 +240,13 @@ def init_coregister_derivatives_wf( workflow = pe.Workflow(name=name) inputnode = pe.Node( niu.IdentityInterface( - fields=['t1w_ref', 't2w_ref', 't1w_mask', 't1w_aseg', 't2w_aseg', 't1w2t2w_xfm'] + fields=['t1w_ref', 't2w_ref', 't1w2t2w_xfm', 't1w_mask', 't1w_aseg', 't2w_aseg'] ), name='inputnode', ) - outputnode = pe.Node(niu.IdentityInterface(fields=['t2w_mask', 't1w_aseg']), name='outputnode') + outputnode = pe.Node( + niu.IdentityInterface(fields=['t2w_mask', 't1w_aseg', 't2w_aseg']), name='outputnode' + ) if t1w_mask: t1wmask2t2w = pe.Node(ApplyTransforms(interpolation="MultiLabel"), name='t1wmask2t2w') @@ -259,18 +261,18 @@ def init_coregister_derivatives_wf( # fmt:on if t1w_aseg: # fmt:off - t1waseg2t2w = pe.Node(ApplyTransforms(interpolation="MultiLabel"), name='t2wmask2t1w') + t1waseg2t2w = pe.Node(ApplyTransforms(interpolation="MultiLabel"), name='t1waseg2t2w') workflow.connect([ (inputnode, t1waseg2t2w, [ - ('t2w_aseg', 'input_image'), + ('t1w_aseg', 'input_image'), ('t1w2t2w_xfm', 'transforms'), - ('t1w_ref', 'reference_image')]), - (t1waseg2t2w, outputnode, [('output_image', 't1w_aseg')]) + ('t2w_ref', 'reference_image')]), + (t1waseg2t2w, outputnode, [('output_image', 't2w_aseg')]) ]) # fmt:on if t2w_aseg: # fmt:off - t2waseg2t1w = pe.Node(ApplyTransforms(interpolation="MultiLabel"), name='t2wmask2t1w') + t2waseg2t1w = pe.Node(ApplyTransforms(interpolation="MultiLabel"), name='t2waseg2t1w') workflow.connect([ (inputnode, t2waseg2t1w, [ ('t2w_aseg', 'input_image'),