diff --git a/.copier-answers.yml b/.copier-answers.yml new file mode 100644 index 0000000..b42deef --- /dev/null +++ b/.copier-answers.yml @@ -0,0 +1,12 @@ +# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY +_commit: 2024.04.23 +_src_path: gh:scientific-python/cookie +backend: hatch +email: f.isensee@dkfz.de +full_name: Fabian Isensee +license: Apache +org: MIC-DKFZ +project_name: HD_BET +project_short_description: Tool for brain extraction +url: https://github.com/MIC-DKFZ/HD_BET +vcs: false diff --git a/.git_archival.txt b/.git_archival.txt new file mode 100644 index 0000000..8fb235d --- /dev/null +++ b/.git_archival.txt @@ -0,0 +1,4 @@ +node: $Format:%H$ +node-date: $Format:%cI$ +describe-name: $Format:%(describe:tags=true,match=*[0-9]*)$ +ref-names: $Format:%D$ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..00a7b00 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +.git_archival.txt export-subst diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..a7f4456 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,101 @@ +See the [Scientific Python Developer Guide][spc-dev-intro] for a detailed +description of best practices for developing scientific packages. + +[spc-dev-intro]: https://learn.scientific-python.org/development/ + +# Quick development + +The fastest way to start with development is to use nox. If you don't have nox, +you can use `pipx run nox` to run it without installing, or `pipx install nox`. +If you don't have pipx (pip for applications), then you can install with +`pip install pipx` (the only case were installing an application with regular +pip is reasonable). If you use macOS, then pipx and nox are both in brew, use +`brew install pipx nox`. + +To use, run `nox`. This will lint and test using every installed version of +Python on your system, skipping ones that are not installed. You can also run +specific jobs: + +```console +$ nox -s lint # Lint only +$ nox -s tests # Python tests +$ nox -s docs -- --serve # Build and serve the docs +$ nox -s build # Make an SDist and wheel +``` + +Nox handles everything for you, including setting up an temporary virtual +environment for each run. + +# Setting up a development environment manually + +You can set up a development environment by running: + +```bash +python3 -m venv .venv +source ./.venv/bin/activate +pip install -v -e .[dev] +``` + +If you have the +[Python Launcher for Unix](https://github.com/brettcannon/python-launcher), you +can instead do: + +```bash +py -m venv .venv +py -m install -v -e .[dev] +``` + +# Post setup + +You should prepare pre-commit, which will help you by checking that commits pass +required checks: + +```bash +pip install pre-commit # or brew install pre-commit on macOS +pre-commit install # Will install a pre-commit hook into the git repo +``` + +You can also/alternatively run `pre-commit run` (changes only) or +`pre-commit run --all-files` to check even without installing the hook. + +# Testing + +Use pytest to run the unit checks: + +```bash +pytest +``` + +# Coverage + +Use pytest-cov to generate coverage reports: + +```bash +pytest --cov=HD_BET +``` + +# Building docs + +You can build the docs using: + +```bash +nox -s docs +``` + +You can see a preview with: + +```bash +nox -s docs -- --serve +``` + +# Pre-commit + +This project uses pre-commit for all style checking. While you can run it with +nox, this is such an important tool that it deserves to be installed on its own. +Install pre-commit and run: + +```bash +pre-commit run -a +``` + +to check all files. diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..6c4b369 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + # Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + groups: + actions: + patterns: + - "*" diff --git a/.github/matchers/pylint.json b/.github/matchers/pylint.json new file mode 100644 index 0000000..e3a6bd1 --- /dev/null +++ b/.github/matchers/pylint.json @@ -0,0 +1,32 @@ +{ + "problemMatcher": [ + { + "severity": "warning", + "pattern": [ + { + "regexp": "^([^:]+):(\\d+):(\\d+): ([A-DF-Z]\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", + "file": 1, + "line": 2, + "column": 3, + "code": 4, + "message": 5 + } + ], + "owner": "pylint-warning" + }, + { + "severity": "error", + "pattern": [ + { + "regexp": "^([^:]+):(\\d+):(\\d+): (E\\d+): \\033\\[[\\d;]+m([^\\033]+).*$", + "file": 1, + "line": 2, + "column": 3, + "code": 4, + "message": 5 + } + ], + "owner": "pylint-error" + } + ] +} diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml new file mode 100644 index 0000000..965645a --- /dev/null +++ b/.github/workflows/cd.yml @@ -0,0 +1,50 @@ +name: CD + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + release: + types: + - published + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + # Many color libraries just need this to be set to any value, but at least + # one distinguishes color depth, where "3" -> "256-bit color". + FORCE_COLOR: 3 + +jobs: + dist: + name: Distribution build + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: hynek/build-and-inspect-python-package@v2 + + publish: + needs: [dist] + name: Publish to PyPI + environment: pypi + permissions: + id-token: write + runs-on: ubuntu-latest + if: github.event_name == 'release' && github.event.action == 'published' + + steps: + - uses: actions/download-artifact@v4 + with: + name: Packages + path: dist + + - uses: pypa/gh-action-pypi-publish@release/v1 + if: github.event_name == 'release' && github.event.action == 'published' diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..1fb4810 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,69 @@ +name: CI + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + # Many color libraries just need this to be set to any value, but at least + # one distinguishes color depth, where "3" -> "256-bit color". + FORCE_COLOR: 3 + +jobs: + pre-commit: + name: Format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - uses: pre-commit/action@v3.0.1 + with: + extra_args: --hook-stage manual --all-files + - name: Run PyLint + run: | + echo "::add-matcher::$GITHUB_WORKSPACE/.github/matchers/pylint.json" + pipx run nox -s pylint + + checks: + name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} + runs-on: ${{ matrix.runs-on }} + needs: [pre-commit] + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.12"] + runs-on: [ubuntu-latest, macos-latest, windows-latest] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + allow-prereleases: true + + - name: Install package + run: python -m pip install .[test] + + - name: Test package + run: >- + python -m pytest -ra --cov --cov-report=xml --cov-report=term + --durations=20 + + - name: Upload coverage report + uses: codecov/codecov-action@v4.3.0 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index 8d82ac6..519ad1d 100644 --- a/.gitignore +++ b/.gitignore @@ -20,9 +20,12 @@ lib64/ parts/ sdist/ var/ +wheels/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg +MANIFEST # PyInstaller # Usually these files are written by a python script from a template @@ -37,13 +40,17 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml -*,cover +*.cover +*.py,cover .hypothesis/ +.pytest_cache/ +cover/ # Translations *.mo @@ -52,6 +59,8 @@ coverage.xml # Django stuff: *.log local_settings.py +db.sqlite3 +db.sqlite3-journal # Flask stuff: instance/ @@ -64,30 +73,92 @@ instance/ docs/_build/ # PyBuilder +.pybuilder/ target/ -# IPython Notebook +# Jupyter Notebook .ipynb_checkpoints +# IPython +profile_default/ +ipython_config.py + # pyenv .python-version -# celery beat schedule file +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff celerybeat-schedule +celerybeat.pid -# dotenv -.env +# SageMath parsed files +*.sage.py -# virtualenv +# Environments +.env +.venv +env/ venv/ ENV/ +env.bak/ +venv.bak/ # Spyder project settings .spyderproject +.spyproject # Rope project settings .ropeproject + + +# mkdocs documentation +/site + + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# setuptools_scm +src/*/_version.py + + +# ruff +.ruff_cache/ + +# OS specific stuff +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Common editor files +*~ +*.swp + *.memmap *.png *.zip @@ -110,4 +181,4 @@ ENV/ *.jpg *.jpeg -*.model \ No newline at end of file +*.model diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..abc0d65 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,73 @@ +ci: + autoupdate_commit_msg: "chore: update pre-commit hooks" + autofix_commit_msg: "style: pre-commit fixes" + +repos: + - repo: https://github.com/adamchainz/blacken-docs + rev: "1.16.0" + hooks: + - id: blacken-docs + additional_dependencies: [black==24.*] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.6.0" + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + - id: name-tests-test + args: ["--pytest-test-first"] + - id: requirements-txt-fixer + - id: trailing-whitespace + + - repo: https://github.com/pre-commit/pygrep-hooks + rev: "v1.10.0" + hooks: + - id: rst-backticks + - id: rst-directive-colons + - id: rst-inline-touching-normal + + - repo: https://github.com/pre-commit/mirrors-prettier + rev: "v3.1.0" + hooks: + - id: prettier + types_or: [yaml, markdown, html, css, scss, javascript, json] + args: [--prose-wrap=always] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: "v0.4.1" + hooks: + - id: ruff + args: ["--fix", "--show-fixes"] + - id: ruff-format + + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: "v0.10.0.1" + hooks: + - id: shellcheck + + - repo: local + hooks: + - id: disallow-caps + name: Disallow improper capitalization + language: pygrep + entry: PyBind|Numpy|Cmake|CCache|Github|PyTest + exclude: .pre-commit-config.yaml + + - repo: https://github.com/abravalheri/validate-pyproject + rev: "v0.16" + hooks: + - id: validate-pyproject + additional_dependencies: ["validate-pyproject-schema-store[all]"] + + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: "0.28.2" + hooks: + - id: check-dependabot + - id: check-github-workflows + - id: check-readthedocs diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..7e49657 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,18 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" +sphinx: + configuration: docs/conf.py + +python: + install: + - method: pip + path: . + extra_requirements: + - docs diff --git a/HD_BET/hd-bet b/HD_BET/hd-bet deleted file mode 100755 index bfd79fa..0000000 --- a/HD_BET/hd-bet +++ /dev/null @@ -1,134 +0,0 @@ -#!/usr/bin/env python - -import os -from HD_BET.run import run_hd_bet -from HD_BET.utils import maybe_mkdir_p, subfiles -import HD_BET - - -if __name__ == "__main__": - print("\n########################") - print("If you are using hd-bet, please cite the following paper:") - print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W," - "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial" - "neural networks. arXiv preprint arXiv:1901.11341, 2019.") - print("########################\n") - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be ' - 'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to ' - 'split 4d sequences into 3d images. If folder: all files ending with .nii.gz ' - 'within that folder will be brain extracted.', required=True, type=str) - parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder' - ' will be created', required=False, type=str) - parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will ' - 'use only one set of parameters whereas accurate will ' - 'use the five sets of parameters that resulted from ' - 'our cross-validation as an ensemble. Default: ' - 'accurate', - required=False) - parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. ' - 'Must be either int or str. Use int for GPU id or ' - '\'cpu\' to run on CPU. When using CPU you should ' - 'consider disabling tta. Default for -device is: 0', - required=False) - parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation ' - '(mirroring). 1= True, 0=False. Disable this ' - 'if you are using CPU to speed things up! ' - 'Default: 1') - parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disabe postprocessing (remove all' - ' but the largest connected component in ' - 'the prediction. Default: 1') - parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation ' - 'mask will not be ' - 'saved') - parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't " - "want to overwrite existing " - "predictions") - parser.add_argument('-b','--bet', default=1, type=int, required=False, help="set this to 0 if you don't want to save skull-stripped brain") - - args = parser.parse_args() - - input_file_or_dir = args.input - output_file_or_dir = args.output - - if output_file_or_dir is None: - output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir), - os.path.basename(input_file_or_dir).split(".")[0] + "_bet") - - mode = args.mode - device = args.device - tta = args.tta - pp = args.pp - save_mask = args.save_mask - overwrite_existing = args.overwrite_existing - bet = args.bet - - params_file = os.path.join(HD_BET.__path__[0], "model_final.py") - config_file = os.path.join(HD_BET.__path__[0], "config.py") - - assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input" - - if device == 'cpu': - pass - else: - device = int(device) - - if os.path.isdir(input_file_or_dir): - maybe_mkdir_p(output_file_or_dir) - input_files = subfiles(input_file_or_dir, suffix='.nii.gz', join=False) - - if len(input_files) == 0: - raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here") - - output_files = [os.path.join(output_file_or_dir, i) for i in input_files] - input_files = [os.path.join(input_file_or_dir, i) for i in input_files] - else: - if not output_file_or_dir.endswith('.nii.gz'): - output_file_or_dir += '.nii.gz' - assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input" - - output_files = [output_file_or_dir] - input_files = [input_file_or_dir] - - if tta == 0: - tta = False - elif tta == 1: - tta = True - else: - raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta)) - - if overwrite_existing == 0: - overwrite_existing = False - elif overwrite_existing == 1: - overwrite_existing = True - else: - raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing)) - - if pp == 0: - pp = False - elif pp == 1: - pp = True - else: - raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp)) - - if save_mask == 0: - save_mask = False - elif save_mask == 1: - save_mask = True - else: - raise ValueError("Unknown value for save_mask: %s. Expected: 0 or 1" % str(save_mask)) - - if bet == 0: - if save_mask: - bet = False - else: - print("Save_mask and bet are set to 0. In this case, Bet is set to 1.") - bet = True - elif bet == 1: - bet = True - else: - raise ValueError("Unknown value for bet: %s. Expected: 0 or 1" % str(pp)) - - run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing, bet) diff --git a/HD_BET/network_architecture.py b/HD_BET/network_architecture.py deleted file mode 100755 index 0824aa1..0000000 --- a/HD_BET/network_architecture.py +++ /dev/null @@ -1,213 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from HD_BET.utils import softmax_helper - - -class EncodingModule(nn.Module): - def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True): - nn.Module.__init__(self) - self.dropout_p = dropout_p - self.lrelu_inplace = lrelu_inplace - self.inst_norm_affine = inst_norm_affine - self.conv_bias = conv_bias - self.leakiness = leakiness - self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) - self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias) - self.dropout = nn.Dropout3d(dropout_p) - self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) - self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias) - - def forward(self, x): - skip = x - x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) - x = self.conv1(x) - if self.dropout_p is not None and self.dropout_p > 0: - x = self.dropout(x) - x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) - x = self.conv2(x) - x = x + skip - return x - - -class Upsample(nn.Module): - def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True): - super(Upsample, self).__init__() - self.align_corners = align_corners - self.mode = mode - self.scale_factor = scale_factor - self.size = size - - def forward(self, x): - return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, - align_corners=self.align_corners) - - -class LocalizationModule(nn.Module): - def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, - lrelu_inplace=True): - nn.Module.__init__(self) - self.lrelu_inplace = lrelu_inplace - self.inst_norm_affine = inst_norm_affine - self.conv_bias = conv_bias - self.leakiness = leakiness - self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias) - self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) - self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias) - self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True) - - def forward(self, x): - x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace) - x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace) - return x - - -class UpsamplingModule(nn.Module): - def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, - lrelu_inplace=True): - nn.Module.__init__(self) - self.lrelu_inplace = lrelu_inplace - self.inst_norm_affine = inst_norm_affine - self.conv_bias = conv_bias - self.leakiness = leakiness - self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True) - self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias) - self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True) - - def forward(self, x): - x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) - return x - - -class DownsamplingModule(nn.Module): - def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, - lrelu_inplace=True): - nn.Module.__init__(self) - self.lrelu_inplace = lrelu_inplace - self.inst_norm_affine = inst_norm_affine - self.conv_bias = conv_bias - self.leakiness = leakiness - self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True) - self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias) - - def forward(self, x): - x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) - b = self.downsample(x) - return x, b - - -class Network(nn.Module): - def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3, - final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True, - lrelu_inplace=True, do_ds=True): - super(Network, self).__init__() - - self.do_ds = do_ds - self.lrelu_inplace = lrelu_inplace - self.inst_norm_affine = inst_norm_affine - self.conv_bias = conv_bias - self.leakiness = leakiness - self.final_nonlin = final_nonlin - self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias) - - self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2, - conv_bias=True, inst_norm_affine=True, lrelu_inplace=True) - - self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True) - self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False) - self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False) - self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True, - inst_norm_affine=True, lrelu_inplace=True) - - self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias) - self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True) - self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias) - self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True) - self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False) - - def forward(self, x): - seg_outputs = [] - - x = self.init_conv(x) - x = self.context1(x) - - skip1, x = self.down1(x) - x = self.context2(x) - - skip2, x = self.down2(x) - x = self.context3(x) - - skip3, x = self.down3(x) - x = self.context4(x) - - skip4, x = self.down4(x) - x = self.context5(x) - - x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace) - x = self.up1(x) - - x = torch.cat((skip4, x), dim=1) - x = self.loc1(x) - x = self.up2(x) - - x = torch.cat((skip3, x), dim=1) - x = self.loc2(x) - loc2_seg = self.final_nonlin(self.loc2_seg(x)) - seg_outputs.append(loc2_seg) - x = self.up3(x) - - x = torch.cat((skip2, x), dim=1) - x = self.loc3(x) - loc3_seg = self.final_nonlin(self.loc3_seg(x)) - seg_outputs.append(loc3_seg) - x = self.up4(x) - - x = torch.cat((skip1, x), dim=1) - x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) - x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness, - inplace=self.lrelu_inplace) - x = self.final_nonlin(self.seg_layer(x)) - seg_outputs.append(x) - - if self.do_ds: - return seg_outputs[::-1] - else: - return seg_outputs[-1] diff --git a/HD_BET/paths.py b/HD_BET/paths.py deleted file mode 100644 index 13b2e65..0000000 --- a/HD_BET/paths.py +++ /dev/null @@ -1,4 +0,0 @@ -import os - -# please refer to the readme on where to get the parameters. Save them in this folder: -folder_with_parameter_files = os.path.join(os.path.expanduser('~'), 'hd-bet_params') diff --git a/LICENSE b/LICENSE index 9c8f3ea..8dada3e 100644 --- a/LICENSE +++ b/LICENSE @@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..1969737 --- /dev/null +++ b/README.md @@ -0,0 +1,176 @@ +# HD_BET + +[![Actions Status][actions-badge]][actions-link] +[![Documentation Status][rtd-badge]][rtd-link] + +[![PyPI version][pypi-version]][pypi-link] +[![Conda-Forge][conda-badge]][conda-link] +[![PyPI platforms][pypi-platforms]][pypi-link] + +[![GitHub Discussion][github-discussions-badge]][github-discussions-link] + + + + +[actions-badge]: https://github.com/MIC-DKFZ/HD_BET/workflows/CI/badge.svg +[actions-link]: https://github.com/MIC-DKFZ/HD_BET/actions +[conda-badge]: https://img.shields.io/conda/vn/conda-forge/HD_BET +[conda-link]: https://github.com/conda-forge/HD_BET-feedstock +[github-discussions-badge]: https://img.shields.io/static/v1?label=Discussions&message=Ask&color=blue&logo=github +[github-discussions-link]: https://github.com/MIC-DKFZ/HD_BET/discussions +[pypi-link]: https://pypi.org/project/HD_BET/ +[pypi-platforms]: https://img.shields.io/pypi/pyversions/HD_BET +[pypi-version]: https://img.shields.io/pypi/v/HD_BET +[rtd-badge]: https://readthedocs.org/projects/HD_BET/badge/?version=latest +[rtd-link]: https://HD_BET.readthedocs.io/en/latest/?badge=latest + + + +This repository provides easy to use access to our recently published HD-BET +brain extraction tool. HD-BET is the result of a joint project between the +Department of Neuroradiology at the Heidelberg University Hospital and the +Division of Medical Image Computing at the German Cancer Research Center (DKFZ). + +If you are using HD-BET, please cite the following publication: + +Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, +Schlemmer HP, Heiland S, Wick W, Bendszus M, Maier-Hein KH, Kickingereder P. +Automated brain extraction of multi-sequence MRI using artificial neural +networks. Hum Brain Mapp. 2019; 1–13. https://doi.org/10.1002/hbm.24750 + +Compared to other commonly used brain extraction tools, HD-BET has some +significant advantages: + +- HD-BET was developed with MRI-data from a large multicentric clinical trial in + adult brain tumor patients acquired across 37 institutions in Europe and + included a broad range of MR hardware and acquisition parameters, pathologies + or treatment-induced tissue alterations. We used 2/3 of data for training and + validation and 1/3 for testing. Moreover independent testing of HD-BET was + performed in three public benchmark datasets (NFBS, LPBA40 and CC-359). +- HD-BET was trained with precontrast T1-w, postcontrast T1-w, T2-w and FLAIR + sequences. It can perform independent brain extraction on various different + MRI sequences and is not restricted to precontrast T1-weighted (T1-w) + sequences. Other MRI sequences may work as well (just give it a try!) +- HD-BET was designed to be robust with respect to brain tumors, lesions and + resection cavities as well as different MRI scanner hardware and acquisition + parameters. +- HD-BET outperformed five publicly available brain extraction algorithms (FSL + BET, AFNI 3DSkullStrip, Brainsuite BSE, ROBEX and BEaST) across all datasets + and yielded median improvements of +1.33 to +2.63 points for the DICE + coefficient and -0.80 to -2.75 mm for the Hausdorff distance + (Bonferroni-adjusted p<0.001). +- HD-BET is very fast on GPU with <10s run time per MRI sequence. Even on CPU it + is not slower than other commonly used tools. + +## Installation Instructions + +Note that you need to have a python3 installation for HD-BET to work. Please +also make sure to install HD-BET with the correct pip version (the one that is +connected to python3). You can verify this using the `--version` command: + +``` +(dl_venv) fabian@Fabian:~$ pip --version +pip 20.0.2 from /home/fabian/dl_venv/lib/python3.6/site-packages/pip (python 3.6) +``` + +If it does not show python 3.X, you can try pip3. If that also does not work you +probably need to install python3 first. + +Once python 3 and pip are set up correctly, run the following commands to +install HD-BET: + +1. Clone this repository: + ```bash + git clone https://github.com/MIC-DKFZ/HD-BET + ``` +2. Go into the repository (the folder with the setup.py file) and install: + ``` + cd HD-BET + pip install -e . + ``` +3. Per default, model parameters will be downloaded to ~/hd-bet_params. If you + wish to use a different folder, open HD_BET/paths.py in a text editor and + modify `folder_with_parameter_files` + +## How to use it + +Using HD_BET is straightforward. You can use it in any terminal on your linux +system. The hd-bet command was installed automatically. We provide CPU as well +as GPU support. Running on GPU is a lot faster though and should always be +preferred. Here is a minimalistic example of how you can use HD-BET (you need to +be in the HD_BET directory) + +```bash +hd-bet -i INPUT_FILENAME +``` + +INPUT_FILENAME must be a nifti (.nii.gz) file containing 3D MRI image data. 4D +image sequences are not supported (however can be splitted upfront into the +individual temporal volumes using fslsplit1). INPUT_FILENAME can be +either a pre- or postcontrast T1-w, T2-w or FLAIR MRI sequence. Other modalities +might work as well. Input images must match the orientation of standard MNI152 +template! Use fslreorient2std 2 upfront to ensure that this is the +case. + +By default, HD-BET will run in GPU mode, use the parameters of all five models +(which originate from a five-fold cross-validation), use test time data +augmentation by mirroring along all axes and not do any postprocessing. + +For batch processing it is faster to process an entire folder at once as this +will mitigate the overhead of loading and initializing the model for each case: + +```bash +hd-bet -i INPUT_FOLDER -o OUTPUT_FOLDER +``` + +The above command will look for all nifti files (\*.nii.gz) in the INPUT_FOLDER +and save the brain masks under the same name in OUTPUT_FOLDER. + +### GPU is nice, but I don't have one of those... What now? + +HD-BET has CPU support. Running on CPU takes a lot longer though and you will +need quite a bit of RAM. To run on CPU, we recommend you use the following +command: + +```bash +hd-bet -i INPUT_FOLDER -o OUTPUT_FOLDER -device cpu -mode fast -tta 0 +``` + +This works of course also with just an input file: + +```bash +hd-bet -i INPUT_FILENAME -device cpu -mode fast -tta 0 +``` + +The options _-mode fast_ and _-tta 0_ will disable test time data augmentation +(speedup of 8x) and use only one model instead of an ensemble of five models for +the prediction. + +### More options: + +For more information, please refer to the help functionality: + +```bash +hd-bet --help +``` + +## FAQ + +1. **How much GPU memory do I need to run HD-BET?** We ran all our experiments + on NVIDIA Titan X GPUs with 12 GB memory. For inference you will need less, + but since inference in implemented by exploiting the fully convolutional + nature of CNNs the amount of memory required depends on your image. Typical + image should run with less than 4 GB of GPU memory consumption. If you run + into out of memory problems please check the following: 1) Make sure the + voxel spacing of your data is correct and 2) Ensure your MRI image only + contains the head region +2. **Will you provide the training code as well?** No. The training code is + tightly wound around the data which we cannot make public. +3. **What run time can I expect on CPU/GPU?** This depends on your MRI image + size. Typical run times (preprocessing, postprocessing and resampling + included) are just a couple of seconds for GPU and about 2 Minutes on CPU + (using `-tta 0 -mode fast`) + +1https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Fslutils + +2https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..49290f9 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +import importlib.metadata + +project = "HD_BET" +copyright = "2024, Fabian Isensee" +author = "Fabian Isensee" +version = release = importlib.metadata.version("hd_bet") + +extensions = [ + "myst_parser", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx_autodoc_typehints", + "sphinx_copybutton", +] + +source_suffix = [".rst", ".md"] +exclude_patterns = [ + "_build", + "**.ipynb_checkpoints", + "Thumbs.db", + ".DS_Store", + ".env", + ".venv", +] + +html_theme = "furo" + +myst_enable_extensions = [ + "colon_fence", +] + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), +} + +nitpick_ignore = [ + ("py:class", "_io.StringIO"), + ("py:class", "_io.BytesIO"), +] + +always_document_param_types = True diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..8fd7c8b --- /dev/null +++ b/docs/index.md @@ -0,0 +1,17 @@ +# HD_BET + +```{toctree} +:maxdepth: 2 +:hidden: + +``` + +```{include} ../README.md +:start-after: +``` + +## Indices and tables + +- {ref}`genindex` +- {ref}`modindex` +- {ref}`search` diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 0000000..0534a01 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import argparse +import shutil +from pathlib import Path + +import nox + +DIR = Path(__file__).parent.resolve() + +nox.needs_version = ">=2024.3.2" +nox.options.sessions = ["lint", "pylint", "tests"] +nox.options.default_venv_backend = "uv|virtualenv" + + +@nox.session +def lint(session: nox.Session) -> None: + """ + Run the linter. + """ + session.install("pre-commit") + session.run( + "pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs + ) + + +@nox.session +def pylint(session: nox.Session) -> None: + """ + Run PyLint. + """ + # This needs to be installed into the package environment, and is slower + # than a pre-commit check + session.install(".", "pylint") + session.run("pylint", "hd_bet", *session.posargs) + + +@nox.session +def tests(session: nox.Session) -> None: + """ + Run the unit and regular tests. + """ + session.install(".[test]") + session.run("pytest", *session.posargs) + + +@nox.session(reuse_venv=True) +def docs(session: nox.Session) -> None: + """ + Build the docs. Pass "--serve" to serve. Pass "-b linkcheck" to check links. + """ + + parser = argparse.ArgumentParser() + parser.add_argument("--serve", action="store_true", help="Serve after building") + parser.add_argument( + "-b", dest="builder", default="html", help="Build target (default: html)" + ) + args, posargs = parser.parse_known_args(session.posargs) + + if args.builder != "html" and args.serve: + session.error("Must not specify non-HTML builder with --serve") + + extra_installs = ["sphinx-autobuild"] if args.serve else [] + + session.install("-e.[docs]", *extra_installs) + session.chdir("docs") + + if args.builder == "linkcheck": + session.run( + "sphinx-build", "-b", "linkcheck", ".", "_build/linkcheck", *posargs + ) + return + + shared_args = ( + "-n", # nitpicky mode + "-T", # full tracebacks + f"-b={args.builder}", + ".", + f"_build/{args.builder}", + *posargs, + ) + + if args.serve: + session.run("sphinx-autobuild", *shared_args) + else: + session.run("sphinx-build", "--keep-going", *shared_args) + + +@nox.session +def build_api_docs(session: nox.Session) -> None: + """ + Build (regenerate) API docs. + """ + + session.install("sphinx") + session.chdir("docs") + session.run( + "sphinx-apidoc", + "-o", + "api/", + "--module-first", + "--no-toc", + "--force", + "../src/hd_bet", + ) + + +@nox.session +def build(session: nox.Session) -> None: + """ + Build an SDist and wheel. + """ + + build_path = DIR.joinpath("build") + if build_path.exists(): + shutil.rmtree(build_path) + + session.install("build") + session.run("python", "-m", "build") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ed2c1b2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,194 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + + +[project] +name = "HD_BET" +authors = [ + { name = "Fabian Isensee", email = "f.isensee@dkfz.de" }, +] +description = "Tool for brain extraction" +readme = "README.md" +license.file = "LICENSE" +requires-python = ">=3.5" +classifiers = [ + "Development Status :: 1 - Planning", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: Unix", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Typing :: Typed", +] +dynamic = ["version"] +dependencies = [ + "numpy", + "torch>=0.4.1", + "scikit-image", + "SimpleITK", +] + +[project.optional-dependencies] +test = [ + "pytest >=6", + "pytest-cov >=3", +] +dev = [ + "pytest >=6", + "pytest-cov >=3", + "pre-commit", +] +docs = [ + "sphinx>=7.0", + "myst_parser>=0.13", + "sphinx_copybutton", + "sphinx_autodoc_typehints", + "furo>=2023.08.17", +] + +[project.urls] +Homepage = "https://github.com/MIC-DKFZ/HD_BET" +"Bug Tracker" = "https://github.com/MIC-DKFZ/HD_BET/issues" +Discussions = "https://github.com/MIC-DKFZ/HD_BET/discussions" +Changelog = "https://github.com/MIC-DKFZ/HD_BET/releases" + +[project.scripts] +hd-bet = "hd_bet.hd_bet_cli:main" + +[tool.hatch] +version.path = "src/hd_bet/__init__.py" + +[tool.hatch.envs.default] +features = ["test"] +scripts.test = "pytest {args}" + + +[tool.pytest.ini_options] +minversion = "6.0" +addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] +xfail_strict = true +filterwarnings = [ + "error", +] +log_cli_level = "INFO" +testpaths = [ + "tests", +] + + +[tool.coverage] +run.source = ["hd_bet"] +report.exclude_also = [ + '\.\.\.', + 'if typing.TYPE_CHECKING:', +] + +[tool.ruff] +src = ["src"] + +[tool.ruff.lint] +extend-select = [ + "B", # flake8-bugbear + "I", # isort + "ARG", # flake8-unused-arguments + "C4", # flake8-comprehensions + "EM", # flake8-errmsg + "ICN", # flake8-import-conventions + "G", # flake8-logging-format + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "RET", # flake8-return + "RUF", # Ruff-specific + "SIM", # flake8-simplify + "T20", # flake8-print + "UP", # pyupgrade + "YTT", # flake8-2020 + "EXE", # flake8-executable + "NPY", # NumPy specific rules + "PD", # pandas-vet +] +ignore = [ + "PLR09", # Too many <...> + "PLR2004", # Magic value used in comparison + "ISC001", # Conflicts with formatter + "EXE002", # The file is executable but no shebang is present + "PTH118", # `os.path.join()` should be replaced by `Path` with `/` operator + "SIM108", # Use ternary operator `l = os.path.join if join else lambda x, y: y` instead of `if`-`else`-block + "E741", # Ambiguous variable name: `l` + "E731", # Do not assign a `lambda` expression, use a `def` + "EM101", # Exception must not use a string literal, assign to variable first + "ARG005", # Unused lambda argument: `x` + "RUF005", # Consider `[1, *list(new_shp)]` instead of concatenation + "ARG005", # Unused lambda argument: `loc` + "B008", # Do not perform function call `os.path.join` in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable + "PTH113", # `os.path.isfile()` should be replaced by `Path.is_file()` + "RET505", # Unnecessary `else` after `return` statement + "PTH120", # `os.path.dirname()` should be replaced by `Path.parent` + "F841", # Local variable `params_file` is assigned to but never used + "PTH112", # `os.path.isdir()` should be replaced by `Path.is_dir()` + "PTH118", # `os.path.join()` should be replaced by `Path.joinpath()` + "PTH102", # `os.mkdir()` should be replaced by `Path.mkdir()` + "C414", # Unnecessary `list` call within `tuple()` + "PTH123", # `open()` should be replaced by `Path.open()` + "PTH111", # `os.path.expanduser()` should be replaced by `Path.expanduser()` + "RET504", # Unnecessary assignment to `x` before `return` statement + "UP008", # Use `super()` instead of `super(__class__, self)` + "PTH119", # `os.path.basename()` should be replaced by `Path.name` + "C419", # Unnecessary list comprehension + "SIM108", # Use ternary operator `x = 8 if do_mirroring else 1` instead of `if`-`else`-block + "T201", # `print` found + "B007", # Loop control variable `i` not used within loop body + "PTH100", # `os.path.abspath()` should be replaced by `Path.resolve()` + "PTH107", # `os.remove()` should be replaced by `Path.unlink()` +] +isort.required-imports = ["from __future__ import annotations"] +# Uncomment if using a _compat.typing backport +# typing-modules = ["hd_bet._compat.typing"] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = ["T20"] +"noxfile.py" = ["T20"] + + +[tool.pylint] +py-version = "3.8" +ignore-paths = [".*/_version.py"] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "design", + "fixme", + "line-too-long", + "missing-module-docstring", + "missing-function-docstring", + "wrong-import-position", + "missing-class-docstring", + "invalid-name", + "import-error", + "consider-using-f-string", + "consider-using-with", + "unnecessary-lambda-assignment", + "super-with-arguments", + "arguments-renamed", + "attribute-defined-outside-init", + "no-member", + "deprecated-module", + "no-else-return", + "use-a-generator", + "consider-using-enumerate", + "superfluous-parens", + "unused-variable", + "import-outside-toplevel", +] diff --git a/readme.md b/readme.md deleted file mode 100644 index 58f0ead..0000000 --- a/readme.md +++ /dev/null @@ -1,125 +0,0 @@ -# HD-BET - -This repository provides easy to use access to our recently published HD-BET brain extraction tool. HD-BET is the result -of a joint project between the Department of Neuroradiology at the Heidelberg University Hospital and the -Division of Medical Image Computing at the German Cancer Research Center (DKFZ). - -If you are using HD-BET, please cite the following publication: - -Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W, -Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial neural -networks. Hum Brain Mapp. 2019; 1–13. https://doi.org/10.1002/hbm.24750 - -Compared to other commonly used brain extraction tools, HD-BET has some significant advantages: -- HD-BET was developed with MRI-data from a large multicentric clinical trial in adult brain tumor patients acquired -across 37 institutions in Europe and included a broad range of MR hardware and acquisition parameters, pathologies -or treatment-induced tissue alterations. We used 2/3 of data for training and validation and 1/3 for testing. -Moreover independent testing of HD-BET was performed in three public benchmark datasets (NFBS, LPBA40 and CC-359). -- HD-BET was trained with precontrast T1-w, postcontrast T1-w, T2-w and FLAIR sequences. It can perform independent -brain extraction on various different MRI sequences and is not restricted to precontrast T1-weighted (T1-w) sequences. - Other MRI sequences may work as well (just give it a try!) -- HD-BET was designed to be robust with respect to brain tumors, lesions and resection cavities as well as different -MRI scanner hardware and acquisition parameters. -- HD-BET outperformed five publicly available brain extraction algorithms (FSL BET, AFNI 3DSkullStrip, Brainsuite BSE, -ROBEX and BEaST) across all datasets and yielded median improvements of +1.33 to +2.63 points for the DICE -coefficient and -0.80 to -2.75 mm for the Hausdorff distance (Bonferroni-adjusted p<0.001). -- HD-BET is very fast on GPU with <10s run time per MRI sequence. Even on CPU it is not slower than other commonly -used tools. - -## Installation Instructions -Note that you need to have a python3 installation for HD-BET to work. Please also make sure to install HD-BET with the -correct pip version (the one that is connected to python3). You can verify this using the `--version` command: - -``` -(dl_venv) fabian@Fabian:~$ pip --version -pip 20.0.2 from /home/fabian/dl_venv/lib/python3.6/site-packages/pip (python 3.6) -``` - -If it does not show python 3.X, you can try pip3. If that also does not work you probably need to install python3 first. - -Once python 3 and pip are set up correctly, run the following commands to install HD-BET: -1) Clone this repository: - ```bash - git clone https://github.com/MIC-DKFZ/HD-BET - ``` -2) Go into the repository (the folder with the setup.py file) and install: - ``` - cd HD-BET - pip install -e . - ``` -3) Per default, model parameters will be downloaded to ~/hd-bet_params. If you wish to use a different folder, open -HD_BET/paths.py in a text editor and modify ```folder_with_parameter_files``` - - -## How to use it - -Using HD_BET is straightforward. You can use it in any terminal on your linux system. The hd-bet command was installed -automatically. We provide CPU as well as GPU support. Running on GPU is a lot faster though -and should always be preferred. Here is a minimalistic example of how you can use HD-BET (you need to be in the HD_BET -directory) - -```bash -hd-bet -i INPUT_FILENAME -``` - -INPUT_FILENAME must be a nifti (.nii.gz) file containing 3D MRI image data. 4D image sequences are not supported -(however can be splitted upfront into the individual temporal volumes using fslsplit1). -INPUT_FILENAME can be either a pre- or postcontrast T1-w, T2-w or FLAIR MRI sequence. Other modalities might work as well. -Input images must match the orientation of standard MNI152 template! Use fslreorient2std 2 upfront to ensure -that this is the case. - -By default, HD-BET will run in GPU mode, use the parameters of all five models (which originate from a five-fold -cross-validation), use test time data augmentation by mirroring along all axes and not do any postprocessing. - -For batch processing it is faster to process an entire folder at once as this will mitigate the overhead of loading -and initializing the model for each case: - -```bash -hd-bet -i INPUT_FOLDER -o OUTPUT_FOLDER -``` - -The above command will look for all nifti files (*.nii.gz) in the INPUT_FOLDER and save the brain masks under the same name -in OUTPUT_FOLDER. - -### GPU is nice, but I don't have one of those... What now? - -HD-BET has CPU support. Running on CPU takes a lot longer though and you will need quite a bit of RAM. To run on CPU, -we recommend you use the following command: - -```bash -hd-bet -i INPUT_FOLDER -o OUTPUT_FOLDER -device cpu -mode fast -tta 0 -``` -This works of course also with just an input file: - -```bash -hd-bet -i INPUT_FILENAME -device cpu -mode fast -tta 0 -``` - -The options *-mode fast* and *-tta 0* will disable test time data augmentation (speedup of 8x) and use only one model instead of an ensemble of five models -for the prediction. - -### More options: -For more information, please refer to the help functionality: - -```bash -hd-bet --help -``` - -## FAQ - -1) **How much GPU memory do I need to run HD-BET?** -We ran all our experiments on NVIDIA Titan X GPUs with 12 GB memory. For inference you will need less, but since -inference in implemented by exploiting the fully convolutional nature of CNNs the amount of memory required depends on -your image. Typical image should run with less than 4 GB of GPU memory consumption. If you run into out of memory -problems please check the following: 1) Make sure the voxel spacing of your data is correct and 2) Ensure your MRI -image only contains the head region -2) **Will you provide the training code as well?** -No. The training code is tightly wound around the data which we cannot make public. -3) **What run time can I expect on CPU/GPU?** -This depends on your MRI image size. Typical run times (preprocessing, postprocessing and resampling included) are just - a couple of seconds for GPU and about 2 Minutes on CPU (using ```-tta 0 -mode fast```) - - -1https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Fslutils - -2https://fsl.fmrib.ox.ac.uk/fsl/fslwiki/Orientation%20Explained \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 7939681..0000000 --- a/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -numpy>=1.14.5 -torch>=0.4.0 -scikit-image>=0.14.0 -SimpleITK>=2.0.2 --e git+https://github.com/MIC-DKFZ/batchgenerators#egg=batchgenerators diff --git a/setup.py b/setup.py deleted file mode 100755 index 8e67f25..0000000 --- a/setup.py +++ /dev/null @@ -1,27 +0,0 @@ -from setuptools import setup, find_packages - -setup(name='HD_BET', - version='1.0', - description='Tool for brain extraction', - url='https://github.com/MIC-DKFZ/hd-bet', - python_requires='>=3.5', - author='Fabian Isensee', - author_email='f.isensee@dkfz.de', - license='Apache 2.0', - zip_safe=False, - install_requires=[ - 'numpy', - 'torch>=0.4.1', - 'scikit-image', - 'SimpleITK' - ], - scripts=['HD_BET/hd-bet'], - packages=find_packages(include=['HD_BET']), - classifiers=[ - 'Intended Audience :: Science/Research', - 'Programming Language :: Python', - 'Topic :: Scientific/Engineering', - 'Operating System :: Unix' - ] - ) - diff --git a/src/hd_bet/__init__.py b/src/hd_bet/__init__.py new file mode 100644 index 0000000..da9e6b6 --- /dev/null +++ b/src/hd_bet/__init__.py @@ -0,0 +1,11 @@ +""" +Copyright (c) 2024 Fabian Isensee. All rights reserved. + +HD_BET: Tool for brain extraction +""" + +from __future__ import annotations + +__version__ = "0.1.0" + +__all__ = ["__version__"] diff --git a/HD_BET/config.py b/src/hd_bet/config.py similarity index 67% rename from HD_BET/config.py rename to src/hd_bet/config.py index 870951e..ab4d41f 100755 --- a/HD_BET/config.py +++ b/src/hd_bet/config.py @@ -1,11 +1,15 @@ +from __future__ import annotations + +from abc import abstractmethod + import numpy as np import torch -from HD_BET.utils import SetNetworkToVal, softmax_helper -from abc import abstractmethod -from HD_BET.network_architecture import Network + +from hd_bet.network_architecture import Network +from hd_bet.utils import SetNetworkToVal, softmax_helper -class BaseConfig(object): +class BaseConfig: def __init__(self): pass @@ -31,8 +35,8 @@ def preprocess(self, data): def __repr__(self): res = "" for v in vars(self): - if not v.startswith("__") and not v.startswith("_") and v != 'dataset': - res += (v + ": " + str(self.__getattribute__(v)) + "\n") + if not v.startswith("__") and not v.startswith("_") and v != "dataset": + res += v + ": " + str(self.__getattribute__(v)) + "\n" return res @@ -40,7 +44,7 @@ class HD_BET_Config(BaseConfig): def __init__(self): super(HD_BET_Config, self).__init__() - self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name + self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name # network parameters self.net_base_num_layers = 21 @@ -62,13 +66,15 @@ def __init__(self): # validation self.val_use_DO = False - self.val_use_train_mode = False # for dropout sampling - self.val_num_repeats = 1 # only useful if dropout sampling - self.val_batch_size = 1 # only useful if dropout sampling + self.val_use_train_mode = False # for dropout sampling + self.val_num_repeats = 1 # only useful if dropout sampling + self.val_batch_size = 1 # only useful if dropout sampling self.val_save_npz = True - self.val_do_mirroring = True # test time data augmentation via mirroring + self.val_do_mirroring = True # test time data augmentation via mirroring self.val_write_images = True - self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property + self.net_input_must_be_divisible_by = ( + 16 # we could make a network class that has this as a property + ) self.val_min_size = self.INPUT_PATCH_SIZE self.val_fn = None @@ -78,13 +84,25 @@ def __init__(self): self.val_use_moving_averages = False def get_network(self, train=True, pretrained_weights=None): - net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers, - self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias, - self.net_norm_use_affine, True, self.net_do_DS) + net = Network( + self.num_classes, + len(self.selected_data_channels), + self.net_base_num_layers, + self.net_dropout_p, + softmax_helper, + self.net_leaky_relu_slope, + self.net_conv_use_bias, + self.net_norm_use_affine, + True, + self.net_do_DS, + ) if pretrained_weights is not None: net.load_state_dict( - torch.load(pretrained_weights, map_location=lambda storage, loc: storage)) + torch.load( + pretrained_weights, map_location=lambda storage, loc: storage + ) + ) if train: net.train(True) @@ -118,4 +136,3 @@ def preprocess(self, data): config = HD_BET_Config - diff --git a/HD_BET/data_loading.py b/src/hd_bet/data_loading.py similarity index 53% rename from HD_BET/data_loading.py rename to src/hd_bet/data_loading.py index 0a953ba..9264739 100755 --- a/HD_BET/data_loading.py +++ b/src/hd_bet/data_loading.py @@ -1,20 +1,28 @@ -import SimpleITK as sitk +from __future__ import annotations + import numpy as np -from skimage.transform import resize +import SimpleITK as sitk +import skimage.transform def resize_image(image, old_spacing, new_spacing, order=3): - new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))), - int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))), - int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2])))) - return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False) + new_shape = ( + int(np.round(old_spacing[0] / new_spacing[0] * float(image.shape[0]))), + int(np.round(old_spacing[1] / new_spacing[1] * float(image.shape[1]))), + int(np.round(old_spacing[2] / new_spacing[2] * float(image.shape[2]))), + ) + return skimage.transform.resize( + image, new_shape, order=order, mode="edge", cval=0, anti_aliasing=False + ) def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)): spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]] image = sitk.GetArrayFromImage(itk_image).astype(float) - assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed" + assert ( + len(image.shape) == 3 + ), "The image has unsupported number of dimensions. Only 3D images are allowed" if not is_seg: if np.any([[i != j] for i, j in zip(spacing, spacing_target)]): @@ -23,9 +31,11 @@ def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)): image -= image.mean() image /= image.std() else: - new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))), - int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))), - int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2])))) + new_shape = ( + int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))), + int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))), + int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))), + ) image = resize_segmentation(image, new_shape, 1) return image @@ -39,16 +49,18 @@ def load_and_preprocess(mri_file): "spacing": images["T1"].GetSpacing(), "direction": images["T1"].GetDirection(), "size": images["T1"].GetSize(), - "origin": images["T1"].GetOrigin() + "origin": images["T1"].GetOrigin(), } - for k in images.keys(): - images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5)) + for k in images: + images[k] = preprocess_image( + images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5) + ) - properties_dict['size_before_cropping'] = images["T1"].shape + properties_dict["size_before_cropping"] = images["T1"].shape imgs = [] - for seq in ['T1']: + for seq in ["T1"]: imgs.append(images[seq][None]) all_data = np.vstack(imgs) print("image shape after preprocessing: ", str(all_data[0].shape)) @@ -56,7 +68,7 @@ def load_and_preprocess(mri_file): def save_segmentation_nifti(segmentation, dct, out_fname, order=1, dtype=np.uint8): - ''' + """ segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out of the original image @@ -72,31 +84,38 @@ def save_segmentation_nifti(segmentation, dct, out_fname, order=1, dtype=np.uint :param dct: :param out_fname: :return: - ''' - old_size = dct.get('size_before_cropping') - bbox = dct.get('brain_bbox') + """ + old_size = dct.get("size_before_cropping") + bbox = dct.get("brain_bbox") if bbox is not None: seg_old_size = np.zeros(old_size) for c in range(3): bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c])) - seg_old_size[bbox[0][0]:bbox[0][1], - bbox[1][0]:bbox[1][1], - bbox[2][0]:bbox[2][1]] = segmentation + seg_old_size[ + bbox[0][0] : bbox[0][1], bbox[1][0] : bbox[1][1], bbox[2][0] : bbox[2][1] + ] = segmentation else: seg_old_size = segmentation - if np.any([i != j for i, j in zip(np.array(seg_old_size), np.array(dct['size'])[[2, 1, 0]])]): - seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order) + if np.any( + [ + i != j + for i, j in zip(np.array(seg_old_size), np.array(dct["size"])[[2, 1, 0]]) + ] + ): + seg_old_spacing = resize_segmentation( + seg_old_size, np.array(dct["size"])[[2, 1, 0]], order=order + ) else: seg_old_spacing = seg_old_size seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(dtype)) - seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]]) - seg_resized_itk.SetOrigin(dct['origin']) - seg_resized_itk.SetDirection(dct['direction']) + seg_resized_itk.SetSpacing(np.array(dct["spacing"])[[0, 1, 2]]) + seg_resized_itk.SetOrigin(dct["origin"]) + seg_resized_itk.SetDirection(dct["direction"]) sitk.WriteImage(seg_resized_itk, out_fname) def resize_segmentation(segmentation, new_shape, order=3, cval=0): - ''' + """ Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one @@ -106,16 +125,33 @@ def resize_segmentation(segmentation, new_shape, order=3, cval=0): :param new_shape: :param order: :return: - ''' + """ tpe = segmentation.dtype unique_labels = np.unique(segmentation) - assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation" + assert len(segmentation.shape) == len( + new_shape + ), "new shape must have same dimensionality as segmentation" if order == 0: - return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe) + return skimage.transform.resize( + segmentation, + new_shape, + order, + mode="constant", + cval=cval, + clip=True, + anti_aliasing=False, + ).astype(tpe) else: reshaped = np.zeros(new_shape, dtype=segmentation.dtype) for i, c in enumerate(unique_labels): - reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False) + reshaped_multihot = skimage.transform.resize( + (segmentation == c).astype(float), + new_shape, + order, + mode="edge", + clip=True, + anti_aliasing=False, + ) reshaped[reshaped_multihot >= 0.5] = c return reshaped diff --git a/src/hd_bet/hd_bet_cli.py b/src/hd_bet/hd_bet_cli.py new file mode 100755 index 0000000..0748049 --- /dev/null +++ b/src/hd_bet/hd_bet_cli.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import os + +import hd_bet +from hd_bet.run import run_hd_bet +from hd_bet.utils import maybe_mkdir_p, subfiles + + +def main(): + print("\n########################") + print("If you are using hd-bet, please cite the following paper:") + print( + "Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W," + "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial" + "neural networks. arXiv preprint arXiv:1901.11341, 2019." + ) + print("########################\n") + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--input", + help="input. Can be either a single file name or an input folder. If file: must be " + "nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to " + "split 4d sequences into 3d images. If folder: all files ending with .nii.gz " + "within that folder will be brain extracted.", + required=True, + type=str, + ) + parser.add_argument( + "-o", + "--output", + help="output. Can be either a filename or a folder. If it does not exist, the folder" + " will be created", + required=False, + type=str, + ) + parser.add_argument( + "-mode", + type=str, + default="accurate", + help="can be either 'fast' or 'accurate'. Fast will " + "use only one set of parameters whereas accurate will " + "use the five sets of parameters that resulted from " + "our cross-validation as an ensemble. Default: " + "accurate", + required=False, + ) + parser.add_argument( + "-device", + default="0", + type=str, + help="used to set on which device the prediction will run. " + "Must be either int or str. Use int for GPU id or " + "'cpu' to run on CPU. When using CPU you should " + "consider disabling tta. Default for -device is: 0", + required=False, + ) + parser.add_argument( + "-tta", + default=1, + required=False, + type=int, + help="whether to use test time data augmentation " + "(mirroring). 1= True, 0=False. Disable this " + "if you are using CPU to speed things up! " + "Default: 1", + ) + parser.add_argument( + "-pp", + default=1, + type=int, + required=False, + help="set to 0 to disabe postprocessing (remove all" + " but the largest connected component in " + "the prediction. Default: 1", + ) + parser.add_argument( + "-s", + "--save_mask", + default=1, + type=int, + required=False, + help="if set to 0 the segmentation " "mask will not be " "saved", + ) + parser.add_argument( + "--overwrite_existing", + default=1, + type=int, + required=False, + help="set this to 0 if you don't " "want to overwrite existing " "predictions", + ) + parser.add_argument( + "-b", + "--bet", + default=1, + type=int, + required=False, + help="set this to 0 if you don't want to save skull-stripped brain", + ) + + args = parser.parse_args() + + input_file_or_dir = args.input + output_file_or_dir = args.output + + if output_file_or_dir is None: + output_file_or_dir = os.path.join( + os.path.dirname(input_file_or_dir), + os.path.basename(input_file_or_dir).split(".")[0] + "_bet", + ) + + mode = args.mode + device = args.device + tta = args.tta + pp = args.pp + save_mask = args.save_mask + overwrite_existing = args.overwrite_existing + bet = args.bet + + params_file = os.path.join(hd_bet.__path__[0], "model_final.py") + config_file = os.path.join(hd_bet.__path__[0], "config.py") + + assert os.path.abspath(input_file_or_dir) != os.path.abspath( + output_file_or_dir + ), "output must be different from input" + + if device == "cpu": + pass + else: + device = int(device) + + if os.path.isdir(input_file_or_dir): + maybe_mkdir_p(output_file_or_dir) + input_files = subfiles(input_file_or_dir, suffix=".nii.gz", join=False) + + if len(input_files) == 0: + raise RuntimeError( + "input is a folder but no nifti files (.nii.gz) were found in here" + ) + + output_files = [os.path.join(output_file_or_dir, i) for i in input_files] + input_files = [os.path.join(input_file_or_dir, i) for i in input_files] + else: + if not output_file_or_dir.endswith(".nii.gz"): + output_file_or_dir += ".nii.gz" + assert os.path.abspath(input_file_or_dir) != os.path.abspath( + output_file_or_dir + ), "output must be different from input" + + output_files = [output_file_or_dir] + input_files = [input_file_or_dir] + + if tta == 0: + tta = False + elif tta == 1: + tta = True + else: + raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta)) + + if overwrite_existing == 0: + overwrite_existing = False + elif overwrite_existing == 1: + overwrite_existing = True + else: + raise ValueError( + "Unknown value for overwrite_existing: %s. Expected: 0 or 1" + % str(overwrite_existing) + ) + + if pp == 0: + pp = False + elif pp == 1: + pp = True + else: + raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp)) + + if save_mask == 0: + save_mask = False + elif save_mask == 1: + save_mask = True + else: + raise ValueError( + "Unknown value for save_mask: %s. Expected: 0 or 1" % str(save_mask) + ) + + if bet == 0: + if save_mask: + bet = False + else: + print("Save_mask and bet are set to 0. In this case, Bet is set to 1.") + bet = True + elif bet == 1: + bet = True + else: + raise ValueError("Unknown value for bet: %s. Expected: 0 or 1" % str(pp)) + + run_hd_bet( + input_files, + output_files, + mode, + config_file, + device, + pp, + tta, + save_mask, + overwrite_existing, + bet, + ) diff --git a/src/hd_bet/network_architecture.py b/src/hd_bet/network_architecture.py new file mode 100755 index 0000000..971a42b --- /dev/null +++ b/src/hd_bet/network_architecture.py @@ -0,0 +1,438 @@ +from __future__ import annotations + +import torch +import torch.nn.functional as F +from torch import nn + +from hd_bet.utils import softmax_helper + + +class EncodingModule(nn.Module): + def __init__( + self, + in_channels, + out_channels, + filter_size=3, + dropout_p=0.3, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ): + nn.Module.__init__(self) + self.dropout_p = dropout_p + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.bn_1 = nn.InstanceNorm3d( + in_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv1 = nn.Conv3d( + in_channels, + out_channels, + filter_size, + 1, + (filter_size - 1) // 2, + bias=self.conv_bias, + ) + self.dropout = nn.Dropout3d(dropout_p) + self.bn_2 = nn.InstanceNorm3d( + in_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv2 = nn.Conv3d( + out_channels, + out_channels, + filter_size, + 1, + (filter_size - 1) // 2, + bias=self.conv_bias, + ) + + def forward(self, x): + skip = x + x = F.leaky_relu( + self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) + x = self.conv1(x) + if self.dropout_p is not None and self.dropout_p > 0: + x = self.dropout(x) + x = F.leaky_relu( + self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) + x = self.conv2(x) + x = x + skip + return x + + +class Upsample(nn.Module): + def __init__( + self, size=None, scale_factor=None, mode="nearest", align_corners=True + ): + super(Upsample, self).__init__() + self.align_corners = align_corners + self.mode = mode + self.scale_factor = scale_factor + self.size = size + + def forward(self, x): + return nn.functional.interpolate( + x, + size=self.size, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + +class LocalizationModule(nn.Module): + def __init__( + self, + in_channels, + out_channels, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ): + nn.Module.__init__(self) + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias) + self.bn_1 = nn.InstanceNorm3d( + in_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias) + self.bn_2 = nn.InstanceNorm3d( + out_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + + def forward(self, x): + x = F.leaky_relu( + self.bn_1(self.conv1(x)), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) + x = F.leaky_relu( + self.bn_2(self.conv2(x)), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) + return x + + +class UpsamplingModule(nn.Module): + def __init__( + self, + in_channels, + out_channels, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ): + nn.Module.__init__(self) + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True) + self.upsample_conv = nn.Conv3d( + in_channels, out_channels, 3, 1, 1, bias=self.conv_bias + ) + self.bn = nn.InstanceNorm3d( + out_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + + def forward(self, x): + x = F.leaky_relu( + self.bn(self.upsample_conv(self.upsample(x))), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) + return x + + +class DownsamplingModule(nn.Module): + def __init__( + self, + in_channels, + out_channels, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ): + nn.Module.__init__(self) + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.bn = nn.InstanceNorm3d( + in_channels, affine=self.inst_norm_affine, track_running_stats=True + ) + self.downsample = nn.Conv3d( + in_channels, out_channels, 3, 2, 1, bias=self.conv_bias + ) + + def forward(self, x): + x = F.leaky_relu( + self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace + ) + b = self.downsample(x) + return x, b + + +class Network(nn.Module): + def __init__( + self, + num_classes=4, + num_input_channels=4, + base_filters=16, + dropout_p=0.3, + final_nonlin=softmax_helper, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + do_ds=True, + ): + super(Network, self).__init__() + + self.do_ds = do_ds + self.lrelu_inplace = lrelu_inplace + self.inst_norm_affine = inst_norm_affine + self.conv_bias = conv_bias + self.leakiness = leakiness + self.final_nonlin = final_nonlin + self.init_conv = nn.Conv3d( + num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias + ) + + self.context1 = EncodingModule( + base_filters, + base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.down1 = DownsamplingModule( + base_filters, + base_filters * 2, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.context2 = EncodingModule( + 2 * base_filters, + 2 * base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.down2 = DownsamplingModule( + 2 * base_filters, + base_filters * 4, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.context3 = EncodingModule( + 4 * base_filters, + 4 * base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.down3 = DownsamplingModule( + 4 * base_filters, + base_filters * 8, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.context4 = EncodingModule( + 8 * base_filters, + 8 * base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.down4 = DownsamplingModule( + 8 * base_filters, + base_filters * 16, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.context5 = EncodingModule( + 16 * base_filters, + 16 * base_filters, + 3, + dropout_p, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.bn_after_context5 = nn.InstanceNorm3d( + 16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True + ) + self.up1 = UpsamplingModule( + 16 * base_filters, + 8 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.loc1 = LocalizationModule( + 16 * base_filters, + 8 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.up2 = UpsamplingModule( + 8 * base_filters, + 4 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.loc2 = LocalizationModule( + 8 * base_filters, + 4 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False) + self.up3 = UpsamplingModule( + 4 * base_filters, + 2 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.loc3 = LocalizationModule( + 4 * base_filters, + 2 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False) + self.up4 = UpsamplingModule( + 2 * base_filters, + 1 * base_filters, + leakiness=1e-2, + conv_bias=True, + inst_norm_affine=True, + lrelu_inplace=True, + ) + + self.end_conv_1 = nn.Conv3d( + 2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias + ) + self.end_conv_1_bn = nn.InstanceNorm3d( + 2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True + ) + self.end_conv_2 = nn.Conv3d( + 2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias + ) + self.end_conv_2_bn = nn.InstanceNorm3d( + 2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True + ) + self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False) + + def forward(self, x): + seg_outputs = [] + + x = self.init_conv(x) + x = self.context1(x) + + skip1, x = self.down1(x) + x = self.context2(x) + + skip2, x = self.down2(x) + x = self.context3(x) + + skip3, x = self.down3(x) + x = self.context4(x) + + skip4, x = self.down4(x) + x = self.context5(x) + + x = F.leaky_relu( + self.bn_after_context5(x), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) + x = self.up1(x) + + x = torch.cat((skip4, x), dim=1) + x = self.loc1(x) + x = self.up2(x) + + x = torch.cat((skip3, x), dim=1) + x = self.loc2(x) + loc2_seg = self.final_nonlin(self.loc2_seg(x)) + seg_outputs.append(loc2_seg) + x = self.up3(x) + + x = torch.cat((skip2, x), dim=1) + x = self.loc3(x) + loc3_seg = self.final_nonlin(self.loc3_seg(x)) + seg_outputs.append(loc3_seg) + x = self.up4(x) + + x = torch.cat((skip1, x), dim=1) + x = F.leaky_relu( + self.end_conv_1_bn(self.end_conv_1(x)), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) + x = F.leaky_relu( + self.end_conv_2_bn(self.end_conv_2(x)), + negative_slope=self.leakiness, + inplace=self.lrelu_inplace, + ) + x = self.final_nonlin(self.seg_layer(x)) + seg_outputs.append(x) + + if self.do_ds: + return seg_outputs[::-1] + else: + return seg_outputs[-1] diff --git a/src/hd_bet/paths.py b/src/hd_bet/paths.py new file mode 100644 index 0000000..e93a734 --- /dev/null +++ b/src/hd_bet/paths.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +import os + +# please refer to the readme on where to get the parameters. Save them in this folder: +folder_with_parameter_files = os.path.join(os.path.expanduser("~"), "hd-bet_params") diff --git a/HD_BET/predict_case.py b/src/hd_bet/predict_case.py similarity index 70% rename from HD_BET/predict_case.py rename to src/hd_bet/predict_case.py index 559c667..f7e2268 100755 --- a/HD_BET/predict_case.py +++ b/src/hd_bet/predict_case.py @@ -1,14 +1,18 @@ -import torch +from __future__ import annotations + import numpy as np +import torch def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None): - if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)): + if not (isinstance(shape_must_be_divisible_by, (list, tuple))): shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3 shp = patient.shape - new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0], - shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1], - shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]] + new_shp = [ + shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0], + shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1], + shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2], + ] for i in range(len(shp)): if shp[i] % shape_must_be_divisible_by[i] == 0: new_shp[i] -= shape_must_be_divisible_by[i] @@ -19,28 +23,41 @@ def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None): def reshape_by_padding_upper_coords(image, new_shape, pad_value=None): shape = tuple(list(image.shape)) - new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0)) + new_shape = tuple( + np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0) + ) if pad_value is None: if len(shape) == 2: - pad_value = image[0,0] + pad_value = image[0, 0] elif len(shape) == 3: pad_value = image[0, 0, 0] else: raise ValueError("Image must be either 2 or 3 dimensional") res = np.ones(list(new_shape), dtype=image.dtype) * pad_value if len(shape) == 2: - res[0:0+int(shape[0]), 0:0+int(shape[1])] = image + res[0 : 0 + int(shape[0]), 0 : 0 + int(shape[1])] = image elif len(shape) == 3: - res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image + res[0 : 0 + int(shape[0]), 0 : 0 + int(shape[1]), 0 : 0 + int(shape[2])] = image return res -def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None, - new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)): +def predict_case_3D_net( + net, + patient_data, + do_mirroring, + num_repeats, + BATCH_SIZE=None, + new_shape_must_be_divisible_by=16, + min_size=None, + main_device=0, + mirror_axes=(2, 3, 4), +): with torch.no_grad(): pad_res = [] for i in range(patient_data.shape[0]): - t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size) + t, old_shape = pad_patient_3D( + patient_data[i], new_shape_must_be_divisible_by, min_size + ) pad_res.append(t[None]) patient_data = np.vstack(pad_res) @@ -56,7 +73,7 @@ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE a = torch.rand(data.shape).float() - if main_device == 'cpu': + if main_device == "cpu": pass else: a = a.cuda(main_device) @@ -72,7 +89,6 @@ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE do_stuff = False if m == 0: do_stuff = True - pass if m == 1 and (4 in mirror_axes): do_stuff = True data_for_net = data_for_net[:, :, :, :, ::-1] @@ -91,13 +107,20 @@ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE if m == 6 and (2 in mirror_axes) and (3 in mirror_axes): do_stuff = True data_for_net = data_for_net[:, :, ::-1, ::-1, :] - if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes): + if ( + m == 7 + and (2 in mirror_axes) + and (3 in mirror_axes) + and (4 in mirror_axes) + ): do_stuff = True data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1] if do_stuff: _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net))) - p = net(a) # np.copy is necessary because ::-1 creates just a view i think + p = net( + a + ) # np.copy is necessary because ::-1 creates just a view i think p = p.data.cpu().numpy() if m == 0: @@ -114,11 +137,18 @@ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE p = p[:, :, ::-1, :, ::-1] if m == 6 and (2 in mirror_axes) and (3 in mirror_axes): p = p[:, :, ::-1, ::-1, :] - if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes): + if ( + m == 7 + and (2 in mirror_axes) + and (3 in mirror_axes) + and (4 in mirror_axes) + ): p = p[:, :, ::-1, ::-1, ::-1] all_preds.append(p) - stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]] + stacked = np.vstack(all_preds)[ + :, :, : old_shape[0], : old_shape[1], : old_shape[2] + ] predicted_segmentation = stacked.mean(0).argmax(0) uncertainty = stacked.var(0) bayesian_predictions = stacked diff --git a/HD_BET/__init__.py b/src/hd_bet/py.typed similarity index 100% rename from HD_BET/__init__.py rename to src/hd_bet/py.typed diff --git a/HD_BET/run.py b/src/hd_bet/run.py similarity index 67% rename from HD_BET/run.py rename to src/hd_bet/run.py index 8c6f08d..7830e52 100755 --- a/HD_BET/run.py +++ b/src/hd_bet/run.py @@ -1,12 +1,21 @@ -import torch -import numpy as np -import SimpleITK as sitk -from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti -from HD_BET.predict_case import predict_case_3D_net +from __future__ import annotations + import imp -from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters import os -import HD_BET + +import numpy as np +import SimpleITK as sitk +import torch + +import hd_bet +from hd_bet.data_loading import load_and_preprocess, save_segmentation_nifti +from hd_bet.predict_case import predict_case_3D_net +from hd_bet.utils import ( + SetNetworkToVal, + get_params_fname, + maybe_download_parameters, + postprocess_prediction, +) def apply_bet(img, bet, out_fname): @@ -19,8 +28,18 @@ def apply_bet(img, bet, out_fname): sitk.WriteImage(out, out_fname) -def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0, - postprocess=False, do_tta=True, keep_mask=True, overwrite=True, bet=False): +def run_hd_bet( + mri_fnames, + output_fnames, + mode="accurate", + config_file=os.path.join(hd_bet.__path__[0], "config.py"), + device=0, + postprocess=False, + do_tta=True, + keep_mask=True, + overwrite=True, + bet=False, +): """ :param mri_fnames: str or list/tuple of str @@ -37,23 +56,27 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j list_of_param_files = [] - if mode == 'fast': + if mode == "fast": params_file = get_params_fname(0) maybe_download_parameters(0) list_of_param_files.append(params_file) - elif mode == 'accurate': + elif mode == "accurate": for i in range(5): params_file = get_params_fname(i) maybe_download_parameters(i) list_of_param_files.append(params_file) else: - raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode) + raise ValueError( + "Unknown value for mode: %s. Expected: fast or accurate" % mode + ) - assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files" + assert all( + [os.path.isfile(i) for i in list_of_param_files] + ), "Could not find parameter files" - cf = imp.load_source('cf', config_file) + cf = imp.load_source("cf", config_file) cf = cf.config() net, _ = cf.get_network(cf.val_use_train_mode, None) @@ -68,7 +91,9 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j if not isinstance(output_fnames, (list, tuple)): output_fnames = [output_fnames] - assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length" + assert len(mri_fnames) == len( + output_fnames + ), "mri_fnames and output_fnames must have the same length" params = [] for p in list_of_param_files: @@ -76,7 +101,10 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j for in_fname, out_fname in zip(mri_fnames, output_fnames): mask_fname = out_fname[:-7] + "_mask.nii.gz" - if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)): + if overwrite or ( + not (os.path.isfile(mask_fname) and keep_mask) + or not os.path.isfile(out_fname) + ): print("File:", in_fname) print("preprocessing...") try: @@ -96,9 +124,17 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j net.load_state_dict(p) net.eval() net.apply(SetNetworkToVal(False, False)) - _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats, - cf.val_batch_size, cf.net_input_must_be_divisible_by, - cf.val_min_size, device, cf.da_mirror_axes) + _, _, softmax_pred, _ = predict_case_3D_net( + net, + data, + do_tta, + cf.val_num_repeats, + cf.val_batch_size, + cf.net_input_must_be_divisible_by, + cf.val_min_size, + device, + cf.da_mirror_axes, + ) softmax_preds.append(softmax_pred[None]) seg = np.argmax(np.vstack(softmax_preds).mean(0), 0) @@ -113,5 +149,3 @@ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.j if not keep_mask: os.remove(mask_fname) - - diff --git a/HD_BET/utils.py b/src/hd_bet/utils.py similarity index 70% rename from HD_BET/utils.py rename to src/hd_bet/utils.py index f70f389..a918f7c 100755 --- a/HD_BET/utils.py +++ b/src/hd_bet/utils.py @@ -1,10 +1,14 @@ +from __future__ import annotations + +import os from urllib.request import urlopen -import torch -from torch import nn + import numpy as np +import torch from skimage.morphology import label -import os -from HD_BET.paths import folder_with_parameter_files +from torch import nn + +from hd_bet.paths import folder_with_parameter_files def get_params_fname(fold): @@ -33,7 +37,7 @@ def maybe_download_parameters(fold=0, force_overwrite=False): url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold print("Downloading", url, "...") data = urlopen(url).read() - with open(out_filename, 'wb') as f: + with open(out_filename, "wb") as f: f.write(data) @@ -52,18 +56,25 @@ def softmax_helper(x): return e_x / e_x.sum(1, keepdim=True).repeat(*rpt) -class SetNetworkToVal(object): +class SetNetworkToVal: def __init__(self, use_dropout_sampling=False, norm_use_average=True): self.norm_use_average = norm_use_average self.use_dropout_sampling = use_dropout_sampling def __call__(self, module): - if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout): + if isinstance(module, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)): module.train(self.use_dropout_sampling) - elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \ - isinstance(module, nn.InstanceNorm1d) \ - or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \ - isinstance(module, nn.BatchNorm1d): + elif isinstance( + module, + ( + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + ), + ): module.train(not self.norm_use_average) @@ -83,9 +94,13 @@ def subdirs(folder, join=True, prefix=None, suffix=None, sort=True): l = os.path.join else: l = lambda x, y: y - res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i)) - and (prefix is None or i.startswith(prefix)) - and (suffix is None or i.endswith(suffix))] + res = [ + l(folder, i) + for i in os.listdir(folder) + if os.path.isdir(os.path.join(folder, i)) + and (prefix is None or i.startswith(prefix)) + and (suffix is None or i.endswith(suffix)) + ] if sort: res.sort() return res @@ -96,9 +111,13 @@ def subfiles(folder, join=True, prefix=None, suffix=None, sort=True): l = os.path.join else: l = lambda x, y: y - res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i)) - and (prefix is None or i.startswith(prefix)) - and (suffix is None or i.endswith(suffix))] + res = [ + l(folder, i) + for i in os.listdir(folder) + if os.path.isfile(os.path.join(folder, i)) + and (prefix is None or i.startswith(prefix)) + and (suffix is None or i.endswith(suffix)) + ] if sort: res.sort() return res @@ -109,6 +128,6 @@ def subfiles(folder, join=True, prefix=None, suffix=None, sort=True): def maybe_mkdir_p(directory): splits = directory.split("/")[1:] - for i in range(0, len(splits)): - if not os.path.isdir(os.path.join("/", *splits[:i+1])): - os.mkdir(os.path.join("/", *splits[:i+1])) + for i in range(len(splits)): + if not os.path.isdir(os.path.join("/", *splits[: i + 1])): + os.mkdir(os.path.join("/", *splits[: i + 1])) diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 0000000..2255f2c --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +import importlib.metadata + +import hd_bet as m + + +def test_version(): + assert importlib.metadata.version("hd_bet") == m.__version__