Skip to content

Commit

Permalink
ENH: Add validation when saving CIFTI2 images
Browse files Browse the repository at this point in the history
- Enabled by default, validation will parse the output filename for a valid CIFTI2 extension.
- If found, the intent code of the image will be set. Also, the CIFTI2Header will be check for compliant index maps for the intent code
  • Loading branch information
mgxd committed Oct 2, 2020
1 parent d0bbcc7 commit b55a363
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 38 deletions.
107 changes: 79 additions & 28 deletions nibabel/cifti2/cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..dataobj_images import DataobjImage
from ..nifti2 import Nifti2Image, Nifti2Header
from ..arrayproxy import reshape_dataobj
from ..volumeutils import Recoder
from warnings import warn


Expand Down Expand Up @@ -90,20 +91,50 @@ class Cifti2HeaderError(Exception):

# "Standard CIFTI Mapping Combinations" within CIFTI-2 spec
# https://www.nitrc.org/forum/attachment.php?attachid=341&group_id=454&forum_id=1955
CIFTI_EXTENSIONS_TO_INTENTS = {
'.dconn': 'NIFTI_INTENT_CONNECTIVITY_DENSE',
'.dtseries': 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES',
'.pconn': 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED',
'.ptseries': 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES',
'.dscalar': 'NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS',
'.dlabel': 'NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS',
'.pscalar': 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR',
'.pdconn': 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE',
'.dpconn': 'NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED',
'.pconnseries': 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SERIES',
'.pconnscalar': 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SCALAR',
'.dfan': 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES',
}
CIFTI_CODES = Recoder((
('dconn', 'NIFTI_INTENT_CONNECTIVITY_DENSE', (
'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('dtseries', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', (
'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('pconn', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('ptseries', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SERIES', (
'CIFTI_INDEX_TYPE_SERIES', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('dscalar', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SCALARS', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('dlabel', 'NIFTI_INTENT_CONNECTIVITY_DENSE_LABELS', (
'CIFTI_INDEX_TYPE_LABELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('pscalar', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_SCALAR', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('pdconn', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_DENSE', (
'CIFTI_INDEX_TYPE_BRAIN_MODELS', 'CIFTI_INDEX_TYPE_PARCELS',
)),
('dpconn', 'NIFTI_INTENT_CONNECTIVITY_DENSE_PARCELLATED', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('pconnseries', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SERIES', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SERIES',
)),
('pconnscalar', 'NIFTI_INTENT_CONNECTIVITY_PARCELLATED_PARCELLATED_SCALAR', (
'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_PARCELS', 'CIFTI_INDEX_TYPE_SCALARS',
)),
('dfan', 'NIFTI_INTENT_CONNECTIVITY_DENSE_SERIES', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('dfibersamp', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
('dfansamp', 'NIFTI_INTENT_CONNECTIVITY_UNKNOWN', (
'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_SCALARS', 'CIFTI_INDEX_TYPE_BRAIN_MODELS',
)),
), fields=('extension', 'niistring', 'map_types'))


def _value_if_klass(val, klass):
Expand Down Expand Up @@ -1503,32 +1534,52 @@ def get_data_dtype(self):
def set_data_dtype(self, dtype):
self._nifti_header.set_data_dtype(dtype)

def to_filename(self, filename, infer_intent=False):
def to_filename(self, filename, validate=True):
"""
Ensures NIfTI header intent code is set prior to saving.
Parameters
----------
infer_intent : boolean, optional
If ``True``, attempt to infer and set intent code based on filename suffix.
validate : boolean, optional
If ``True``, infer and validate CIFTI type based on filename suffix.
This includes the setting of the NIfTI intent code and checking the ``CIFTI2Matrix``
for the expected IndicesMaps attributes.
If validation fails, an error will be raised instead.
"""
header = self._nifti_header
if infer_intent:
# try to infer intent code based on filename suffix
intent = _infer_intent_from_filename(filename)
if intent is not None:
header.set_intent(intent)
nheader = self._nifti_header
# try to infer intent code based on filename suffix
if validate:
ext = _extract_cifti_extension(filename)
try:
CIFTI_CODES.extension[ext]
except KeyError as err:
raise KeyError(
f"Validation failed: No information for extension {ext} available"
) from err
intent = CIFTI_CODES.niistring[ext]
nheader.set_intent(intent)
# validate matrix indices
for idx, mtype in enumerate(CIFTI_CODES.map_types[ext]):
try:
assert self.header.matrix.get_index_map(idx).indices_map_to_data_type == mtype
except Exception:
raise Cifti2HeaderError(
f"Validation failed: Cifti2Matrix index map {idx} does "
f"not match expected type {mtype}"
)
# if intent code is not set, default to unknown
if header.get_intent()[0] == 'none':
header.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')
if nheader.get_intent()[0] == 'none':
nheader.set_intent('NIFTI_INTENT_CONNECTIVITY_UNKNOWN')
super().to_filename(filename)


def _infer_intent_from_filename(filename):
def _extract_cifti_extension(filename):
"""Parses output filename for common suffixes and fetches corresponding intent code"""
from pathlib import Path
ext = Path(filename).suffixes[0]
return CIFTI_EXTENSIONS_TO_INTENTS.get(ext)
_suf = Path(filename).suffixes
# select second to last if possible (.<suffix>.nii)
ext = _suf[-2] if len(_suf) >= 2 else _suf[0]
return ext.lstrip('.')


load = Cifti2Image.from_filename
Expand Down
62 changes: 52 additions & 10 deletions nibabel/cifti2/tests/test_new_cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_dtseries():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.dtseries.nii', infer_intent=True)
ci.save(img, 'test.dtseries.nii')
img2 = nib.load('test.dtseries.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnDenseSeries'
assert isinstance(img2, ci.Cifti2Image)
Expand Down Expand Up @@ -282,7 +282,7 @@ def test_dlabel():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.dlabel.nii', infer_intent=True)
ci.save(img, 'test.dlabel.nii')
img2 = nib.load('test.dlabel.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnDenseLabel'
assert isinstance(img2, ci.Cifti2Image)
Expand All @@ -301,7 +301,7 @@ def test_dconn():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.dconn.nii', infer_intent=True)
ci.save(img, 'test.dconn.nii')
img2 = nib.load('test.dconn.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnDense'
assert isinstance(img2, ci.Cifti2Image)
Expand All @@ -322,7 +322,7 @@ def test_ptseries():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.ptseries.nii', infer_intent=True)
ci.save(img, 'test.ptseries.nii')
img2 = nib.load('test.ptseries.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnParcelSries'
assert isinstance(img2, ci.Cifti2Image)
Expand All @@ -343,7 +343,7 @@ def test_pscalar():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.pscalar.nii', infer_intent=True)
ci.save(img, 'test.pscalar.nii')
img2 = nib.load('test.pscalar.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnParcelScalr'
assert isinstance(img2, ci.Cifti2Image)
Expand All @@ -364,7 +364,7 @@ def test_pdconn():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.pdconn.nii', infer_intent=True)
ci.save(img, 'test.pdconn.nii')
img2 = ci.load('test.pdconn.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnParcelDense'
assert isinstance(img2, ci.Cifti2Image)
Expand All @@ -385,7 +385,7 @@ def test_dpconn():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.dpconn.nii', infer_intent=True)
ci.save(img, 'test.dpconn.nii')
img2 = ci.load('test.dpconn.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnDenseParcel'
assert isinstance(img2, ci.Cifti2Image)
Expand Down Expand Up @@ -425,7 +425,7 @@ def test_pconn():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.pconn.nii', infer_intent=True)
ci.save(img, 'test.pconn.nii')
img2 = ci.load('test.pconn.nii')
assert img.nifti_header.get_intent()[0] == 'ConnParcels'
assert isinstance(img2, ci.Cifti2Image)
Expand All @@ -447,7 +447,7 @@ def test_pconnseries():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.pconnseries.nii', infer_intent=True)
ci.save(img, 'test.pconnseries.nii')
img2 = ci.load('test.pconnseries.nii')
assert img.nifti_header.get_intent()[0] == 'ConnPPSr'
assert isinstance(img2, ci.Cifti2Image)
Expand All @@ -470,7 +470,7 @@ def test_pconnscalar():
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
ci.save(img, 'test.pconnscalar.nii', infer_intent=True)
ci.save(img, 'test.pconnscalar.nii')
img2 = ci.load('test.pconnscalar.nii')
assert img.nifti_header.get_intent()[0] == 'ConnPPSc'
assert isinstance(img2, ci.Cifti2Image)
Expand Down Expand Up @@ -509,3 +509,45 @@ def test_wrong_shape():
with pytest.raises(ValueError):
img.to_file_map()


def test_cifti_validation():
# flip label / brain_model index maps
geometry_map = create_geometry_map((0, ))
label_map = create_label_map((1, ))
matrix = ci.Cifti2Matrix()
matrix.append(label_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(10, 2)
img = ci.Cifti2Image(data, hdr)

# attempt to save and validate with an invalid extension
with pytest.raises(KeyError):
ci.save(img, 'test.dlabelz.nii')
# even with a proper extension, flipped index maps will fail
with pytest.raises(ci.Cifti2HeaderError):
ci.save(img, 'test.dlabel.nii')

label_map = create_label_map((0, ))
geometry_map = create_geometry_map((1, ))
matrix = ci.Cifti2Matrix()
matrix.append(label_map)
matrix.append(geometry_map)
hdr = ci.Cifti2Header(matrix)
data = np.random.randn(2, 10)
img = ci.Cifti2Image(data, hdr)

with InTemporaryDirectory():
# still fail with invalid extension and validation
with pytest.raises(KeyError):
ci.save(img, 'test.dlabelz.nii')
# but removing validation should work (though intent code will be unknown)
ci.save(img, 'test.dlabelz.nii', validate=False)

img2 = nib.load('test.dlabelz.nii')
assert img2.nifti_header.get_intent()[0] == 'ConnUnknown'
assert isinstance(img2, ci.Cifti2Image)
assert_array_equal(img2.get_fdata(), data)
check_label_map(img2.header.matrix.get_index_map(0))
check_geometry_map(img2.header.matrix.get_index_map(1))
del img2

0 comments on commit b55a363

Please sign in to comment.