Skip to content

Commit

Permalink
Merge pull request #401 from mgxd/maint/update-parser
Browse files Browse the repository at this point in the history
MAINT: Port over parser and tests from fmriprep
  • Loading branch information
mgxd authored Sep 27, 2024
2 parents 6e48832 + 8162d10 commit 6d778c4
Show file tree
Hide file tree
Showing 4 changed files with 323 additions and 13 deletions.
48 changes: 35 additions & 13 deletions nibabies/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,36 @@ def _drop_sub(value):
def _drop_ses(value):
return value[4:] if value.startswith('ses-') else value

def _filter_pybids_none_any(dct):
def _process_value(value):
import bids

return {
k: bids.layout.Query.NONE if v is None else (bids.layout.Query.ANY if v == '*' else v)
for k, v in dct.items()
}

def _bids_filter(value):
from json import loads
if value is None:
return bids.layout.Query.NONE
elif value == '*':
return bids.layout.Query.ANY
else:
return value

if value and Path(value).exists():
return loads(Path(value).read_text(), object_hook=_filter_pybids_none_any)
def _filter_pybids_none_any(dct):
d = {}
for k, v in dct.items():
if isinstance(v, list):
d[k] = [_process_value(val) for val in v]
else:
d[k] = _process_value(v)
return d

def _bids_filter(value, parser):
from json import JSONDecodeError, loads

if value:
if Path(value).exists():
try:
return loads(Path(value).read_text(), object_hook=_filter_pybids_none_any)
except JSONDecodeError as e:
raise parser.error(f'JSON syntax error in: <{value}>.') from e
else:
raise parser.error(f'Path does not exist: <{value}>.')

def _slice_time_ref(value, parser):
if value == 'start':
Expand Down Expand Up @@ -142,6 +159,7 @@ def _slice_time_ref(value, parser):
DirNotEmpty = partial(_dir_not_empty, parser=parser)
IsFile = partial(_is_file, parser=parser)
PositiveInt = partial(_min_one, parser=parser)
BIDSFilter = partial(_bids_filter, parser=parser)
SliceTimeRef = partial(_slice_time_ref, parser=parser)

parser.description = f"""
Expand Down Expand Up @@ -214,7 +232,7 @@ def _slice_time_ref(value, parser):
'--bids-filter-file',
dest='bids_filters',
action='store',
type=_bids_filter,
type=BIDSFilter,
metavar='FILE',
help='a JSON file describing custom BIDS input filters using PyBIDS. '
'For further details, please check out '
Expand Down Expand Up @@ -546,9 +564,13 @@ def _slice_time_ref(value, parser):
g_syn = parser.add_argument_group('Specific options for SyN distortion correction')
g_syn.add_argument(
'--use-syn-sdc',
action='store_true',
nargs='?',
choices=['warn', 'error'],
action='store',
const='error',
default=False,
help='EXPERIMENTAL: Use fieldmap-free distortion correction',
help='Use fieldmap-less distortion correction based on anatomical image; '
'if unable, error (default) or warn based on optional argument.',
)
g_syn.add_argument(
'--force-syn',
Expand Down
Empty file added nibabies/cli/tests/__init__.py
Empty file.
272 changes: 272 additions & 0 deletions nibabies/cli/tests/test_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
"""Test parser."""

from argparse import ArgumentError
from contextlib import nullcontext

import pytest
from packaging.version import Version

from nibabies import config
from nibabies.cli import version as _version
from nibabies.cli.parser import _build_parser, parse_args
from nibabies.tests.test_config import _reset_config

MIN_ARGS = ['data/', 'out/', 'participant']


@pytest.mark.parametrize(
('args', 'code'),
[
([], 2),
(MIN_ARGS, 2), # bids_dir does not exist
(MIN_ARGS + ['--fs-license-file'], 2),
(MIN_ARGS + ['--fs-license-file', 'fslicense.txt'], 2),
],
)
def test_parser_errors(args, code):
"""Check behavior of the parser."""
with pytest.raises(SystemExit) as error:
_build_parser().parse_args(args)

assert error.value.code == code


@pytest.mark.parametrize('args', [MIN_ARGS, MIN_ARGS + ['--fs-license-file']])
def test_parser_valid(tmp_path, args):
"""Check valid arguments."""
datapath = tmp_path / 'data'
datapath.mkdir(exist_ok=True)
args[0] = str(datapath)

if '--fs-license-file' in args:
_fs_file = tmp_path / 'license.txt'
_fs_file.write_text('')
args.insert(args.index('--fs-license-file') + 1, str(_fs_file.absolute()))

opts = _build_parser().parse_args(args)

assert opts.bids_dir == datapath


@pytest.mark.parametrize(
('argval', 'gb'),
[
('1G', 1),
('1GB', 1),
('1000', 1), # Default units are MB
('32000', 32), # Default units are MB
('4000', 4), # Default units are MB
('1000M', 1),
('1000MB', 1),
('1T', 1000),
('1TB', 1000),
('%dK' % 1e6, 1),
('%dKB' % 1e6, 1),
('%dB' % 1e9, 1),
],
)
def test_memory_arg(tmp_path, argval, gb):
"""Check the correct parsing of the memory argument."""
datapath = tmp_path / 'data'
datapath.mkdir(exist_ok=True)
_fs_file = tmp_path / 'license.txt'
_fs_file.write_text('')

args = [str(datapath)] + MIN_ARGS[1:] + ['--fs-license-file', str(_fs_file), '--mem', argval]
opts = _build_parser().parse_args(args)

assert opts.memory_gb == gb


@pytest.mark.parametrize(('current', 'latest'), [('1.0.0', '1.3.2'), ('1.3.2', '1.3.2')])
def test_get_parser_update(monkeypatch, capsys, current, latest):
"""Make sure the out-of-date banner is shown."""
expectation = Version(current) < Version(latest)

def _mock_check_latest(*args, **kwargs):
return Version(latest)

monkeypatch.setattr(config.environment, 'version', current)
monkeypatch.setattr(_version, 'check_latest', _mock_check_latest)

_build_parser()
captured = capsys.readouterr().err

msg = f"""\
You are using NiBabies-{current}, and a newer version of NiBabies is available: {latest}.
Please check out our documentation about how and when to upgrade:
https://fmriprep.readthedocs.io/en/latest/faq.html#upgrading"""

assert (msg in captured) is expectation


@pytest.mark.parametrize('flagged', [(True, None), (True, 'random reason'), (False, None)])
def test_get_parser_blacklist(monkeypatch, capsys, flagged):
"""Make sure the blacklisting banner is shown."""

def _mock_is_bl(*args, **kwargs):
return flagged

monkeypatch.setattr(_version, 'is_flagged', _mock_is_bl)

_build_parser()
captured = capsys.readouterr().err

assert ('FLAGGED' in captured) is flagged[0]
if flagged[0]:
assert (flagged[1] or 'reason: unknown') in captured


def test_parse_args(tmp_path, minimal_bids):
"""Basic smoke test showing that our parse_args() function
implements the BIDS App protocol"""
out_dir = tmp_path / 'out'
work_dir = tmp_path / 'work'

parse_args(
args=[
str(minimal_bids),
str(out_dir),
'participant', # BIDS App
'-w',
str(work_dir), # Don't pollute CWD
'--skip-bids-validation', # Empty files make BIDS sad
]
)
assert config.execution.layout.root == str(minimal_bids)
_reset_config()


def test_bids_filter_file(tmp_path, capsys):
bids_path = tmp_path / 'data'
out_path = tmp_path / 'out'
bff = tmp_path / 'filter.json'
args = [str(bids_path), str(out_path), 'participant', '--bids-filter-file', str(bff)]
bids_path.mkdir()

parser = _build_parser()

with pytest.raises(SystemExit):
parser.parse_args(args)

err = capsys.readouterr().err
assert 'Path does not exist:' in err

bff.write_text('{"invalid json": }')

with pytest.raises(SystemExit):
parser.parse_args(args)

err = capsys.readouterr().err
assert 'JSON syntax error in:' in err
_reset_config()


@pytest.mark.parametrize('st_ref', (None, '0', '1', '0.5', 'start', 'middle')) # noqa: PT007
def test_slice_time_ref(tmp_path, st_ref):
bids_path = tmp_path / 'data'
out_path = tmp_path / 'out'
args = [str(bids_path), str(out_path), 'participant']
if st_ref:
args.extend(['--slice-time-ref', st_ref])
bids_path.mkdir()

parser = _build_parser()

parser.parse_args(args)
_reset_config()


@pytest.mark.parametrize(
('args', 'expectation'),
[
([], False),
(['--use-syn-sdc'], 'error'),
(['--use-syn-sdc', 'error'], 'error'),
(['--use-syn-sdc', 'warn'], 'warn'),
(['--use-syn-sdc', 'other'], (SystemExit, ArgumentError)),
],
)
def test_use_syn_sdc(tmp_path, args, expectation):
bids_path = tmp_path / 'data'
out_path = tmp_path / 'out'
args = [str(bids_path), str(out_path), 'participant'] + args
bids_path.mkdir()

parser = _build_parser()

cm = nullcontext()
if isinstance(expectation, tuple):
cm = pytest.raises(expectation)

with cm:
opts = parser.parse_args(args)

if not isinstance(expectation, tuple):
assert opts.use_syn_sdc == expectation

_reset_config()


def test_derivatives(tmp_path):
"""Check the correct parsing of the derivatives argument."""
bids_path = tmp_path / 'data'
out_path = tmp_path / 'out'
args = [str(bids_path), str(out_path), 'participant']
bids_path.mkdir()

parser = _build_parser()

# Providing --derivatives without a path should raise an error
temp_args = args + ['--derivatives']
with pytest.raises((SystemExit, ArgumentError)):
parser.parse_args(temp_args)
_reset_config()

# Providing --derivatives without names should automatically label them
temp_args = args + ['--derivatives', str(bids_path / 'derivatives/smriprep')]
opts = parser.parse_args(temp_args)
assert opts.derivatives == {'smriprep': bids_path / 'derivatives/smriprep'}
_reset_config()

# Providing --derivatives with names should use them
temp_args = args + [
'--derivatives',
f'anat={str(bids_path / "derivatives/smriprep")}',
]
opts = parser.parse_args(temp_args)
assert opts.derivatives == {'anat': bids_path / 'derivatives/smriprep'}
_reset_config()

# Providing multiple unlabeled derivatives with the same name should raise an error
temp_args = args + [
'--derivatives',
str(bids_path / 'derivatives_01/smriprep'),
str(bids_path / 'derivatives_02/smriprep'),
]
with pytest.raises(ValueError, match='Received duplicate derivative name'):
parser.parse_args(temp_args)

_reset_config()
16 changes: 16 additions & 0 deletions nibabies/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""py.test configuration"""

import json
from pathlib import Path
from tempfile import TemporaryDirectory

import nibabel as nb
import numpy as np
import pytest

from nibabies.data import load as load_data
Expand Down Expand Up @@ -37,3 +40,16 @@ def _populate_namespace(doctest_namespace, data_dir):
doctest_namespace['data_dir'] = data_dir
doctest_namespace['test_data'] = load_data.cached('../tests/data')
doctest_namespace['Path'] = Path


@pytest.fixture
def minimal_bids(tmp_path):
bids = tmp_path / 'bids'
bids.mkdir()
Path.write_text(
bids / 'dataset_description.json', json.dumps({'Name': 'Test DS', 'BIDSVersion': '1.8.0'})
)
T1w = bids / 'sub-01' / 'anat' / 'sub-01_T1w.nii.gz'
T1w.parent.mkdir(parents=True)
nb.Nifti1Image(np.zeros((5, 5, 5)), np.eye(4)).to_filename(T1w)
return bids

0 comments on commit 6d778c4

Please sign in to comment.