diff --git a/.github/dco.yml b/.github/dco.yml
new file mode 100644
index 00000000000..7993b95cc24
--- /dev/null
+++ b/.github/dco.yml
@@ -0,0 +1,2 @@
+allowRemediationCommits:
+ individual: true
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index e1eeb92c6b7..9da78d18cd9 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -1,10 +1,8 @@
Fixes # .
### Description
-A few sentences describing the changes proposed in this pull request.
-### Status
-**Ready/Work in progress/Hold**
+A few sentences describing the changes proposed in this pull request.
### Types of changes
diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml
index 14f768111e0..1717618e9cd 100644
--- a/.github/workflows/blossom-ci.yml
+++ b/.github/workflows/blossom-ci.yml
@@ -59,7 +59,7 @@ jobs:
#- name: Setup java
# uses: actions/setup-java@v1
# with:
- # java-version: 1.8
+ # java-version: '1.8'
# add blackduck properties https://synopsys.atlassian.net/wiki/spaces/INTDOCS/pages/631308372/Methods+for+Configuring+Analysis#Using-a-configuration-file
#- name: Setup blackduck properties
diff --git a/.github/workflows/chatops.yml b/.github/workflows/chatops.yml
index 9458fff298c..ac6bc1eb347 100644
--- a/.github/workflows/chatops.yml
+++ b/.github/workflows/chatops.yml
@@ -1,3 +1,4 @@
+# triggering the workflows by commenting `/black` and `/integration-test`
name: chatops
# currently dispatches /black command to project-monai/monai-code-formatter
diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml
index 98c194f474f..c2a78b67933 100644
--- a/.github/workflows/conda.yml
+++ b/.github/workflows/conda.yml
@@ -1,4 +1,5 @@
-name: conda
+# daily tests for different OS with conda
+name: cron-conda
on:
schedule:
@@ -18,7 +19,7 @@ jobs:
fail-fast: false
matrix:
os: [windows-latest, macOS-latest, ubuntu-latest]
- python-version: ["3.7"]
+ python-version: ["3.8"]
runs-on: ${{ matrix.os }}
env:
QUICKTEST: True
diff --git a/.github/workflows/cron-mmar.yml b/.github/workflows/cron-mmar.yml
index 46bf7ff384f..95167fba3dd 100644
--- a/.github/workflows/cron-mmar.yml
+++ b/.github/workflows/cron-mmar.yml
@@ -1,3 +1,4 @@
+# daily tests for clara mmar models
name: cron-mmar
on:
@@ -18,12 +19,12 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- name: cache weekly timestamp
id: pip-cache
- run: echo "::set-output name=datew::$(date '+%Y-%V')"
+ run: echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
uses: actions/cache@v3
id: cache
diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml
index b143b26ce0b..9410d5d58dd 100644
--- a/.github/workflows/cron.yml
+++ b/.github/workflows/cron.yml
@@ -1,4 +1,5 @@
-name: crons
+# nightly: Jenkinsfile.monai-pytorch-versions, monai-latest-image, monai-pip, monai-latest-docker, monai-notebooks
+name: nightly-crons
on:
# schedule:
@@ -9,31 +10,86 @@ on:
jobs:
cron-gpu:
if: github.repository == 'Project-MONAI/MONAI'
+ strategy:
+ matrix:
+ environment:
+ - "PT182+CUDA102"
+ - "PT191+CUDA113"
+ - "PT110+CUDA113"
+ - "PT112+CUDA113"
+ - "PTLATEST+CUDA117"
+ include:
+ # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes
+ - environment: PT182+CUDA102
+ pytorch: "torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu102"
+ base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
+ - environment: PT191+CUDA113
+ pytorch: "torch==1.9.1 torchvision==0.10.1 --extra-index-url https://download.pytorch.org/whl/cu113"
+ base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3
+ - environment: PT110+CUDA113
+ pytorch: "torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu113"
+ base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3
+ - environment: PT112+CUDA113
+ pytorch: "torch==1.12.1 torchvision==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu113"
+ base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3
+ - environment: PTLATEST+CUDA117
+ pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117"
+ base: "nvcr.io/nvidia/pytorch:22.08-py3" # CUDA 11.7
container:
- image: nvcr.io/nvidia/pytorch:21.06-py3 # CUDA 11.3
+ image: ${{ matrix.base }}
options: "--gpus all"
runs-on: [self-hosted, linux, x64, common]
- strategy:
- matrix:
- pytorch-version: [1.7.1, 1.8.1, 1.9.1, 1.10.2, latest]
steps:
- uses: actions/checkout@v3
+ - name: apt install
+ run: |
+ # FIXME: workaround for https://github.com/Project-MONAI/MONAI/issues/4200
+ apt-key del 7fa2af80 && rm -rf /etc/apt/sources.list.d/nvidia-ml.list /etc/apt/sources.list.d/cuda.list
+ apt-get update
+ apt-get install -y wget
+ wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb
+ dpkg -i cuda-keyring_1.0-1_all.deb
+
+ if [ ${{ matrix.environment }} = "PT182+CUDA102" ]
+ then
+ PYVER=3.7 PYSFX=3 DISTUTILS=python3-distutils && \
+ apt-get update && apt-get install -y --no-install-recommends \
+ curl \
+ pkg-config \
+ python$PYVER \
+ python$PYVER-dev \
+ python$PYSFX-pip \
+ $DISTUTILS \
+ rsync \
+ swig \
+ unzip \
+ zip \
+ zlib1g-dev \
+ libboost-locale-dev \
+ libboost-program-options-dev \
+ libboost-system-dev \
+ libboost-thread-dev \
+ libboost-test-dev \
+ libgoogle-glog-dev \
+ libjsoncpp-dev \
+ cmake \
+ git && \
+ rm -rf /var/lib/apt/lists/* && \
+ export PYTHONIOENCODING=utf-8 LC_ALL=C.UTF-8 && \
+ rm -f /usr/bin/python && \
+ rm -f /usr/bin/python`echo $PYVER | cut -c1-1` && \
+ ln -s /usr/bin/python$PYVER /usr/bin/python && \
+ ln -s /usr/bin/python$PYVER /usr/bin/python`echo $PYVER | cut -c1-1` &&
+ curl -O https://bootstrap.pypa.io/get-pip.py && \
+ python get-pip.py && \
+ rm get-pip.py;
+ fi
- name: Install the dependencies
run: |
which python
python -m pip install --upgrade pip wheel
python -m pip uninstall -y torch torchvision
- if [ ${{ matrix.pytorch-version }} == "latest" ]; then
- python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
- elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then
- python -m pip install torch==1.7.1 torchvision==0.8.2 --extra-index-url https://download.pytorch.org/whl/cu113
- elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then
- python -m pip install torch==1.8.1 torchvision==0.9.1 --extra-index-url https://download.pytorch.org/whl/cu113
- elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then
- python -m pip install torch==1.9.1 torchvision==0.10.1 --extra-index-url https://download.pytorch.org/whl/cu113
- elif [ ${{ matrix.pytorch-version }} == "1.10.2" ]; then
- python -m pip install torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu113
- fi
+ python -m pip install ${{ matrix.pytorch }}
python -m pip install -r requirements-dev.txt
python -m pip list
- name: Run tests report coverage
@@ -42,7 +98,7 @@ jobs:
echo "Sleep $LAUNCH_DELAY"
sleep $LAUNCH_DELAY
nvidia-smi
- export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)
echo $CUDA_VISIBLE_DEVICES
trap 'if pgrep python; then pkill python; fi;' ERR
python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
@@ -50,23 +106,24 @@ jobs:
python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests # unit tests with coverage report
BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report
- coverage xml
+ coverage xml --ignore-errors
if pgrep python; then pkill python; fi
+ shell: bash
- name: Upload coverage
- uses: codecov/codecov-action@v1
+ uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
- file: ./coverage.xml
+ files: ./coverage.xml
cron-pt-image:
if: github.repository == 'Project-MONAI/MONAI'
strategy:
matrix:
- container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.07"] # 21.02, 21.10 for backward comp.
+ container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.10"] # 21.02, 21.10 for backward comp.
container:
image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image
options: "--gpus all"
- runs-on: [self-hosted, linux, x64, common]
+ runs-on: [self-hosted, linux, x64, integration]
steps:
- uses: actions/checkout@v3
- name: Install APT dependencies
@@ -85,7 +142,7 @@ jobs:
echo "Sleep $LAUNCH_DELAY"
sleep $LAUNCH_DELAY
nvidia-smi
- export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)
echo $CUDA_VISIBLE_DEVICES
trap 'if pgrep python; then pkill python; fi;' ERR
python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
@@ -93,24 +150,25 @@ jobs:
python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests # unit tests with coverage report
BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report
- coverage xml
+ coverage xml --ignore-errors
if pgrep python; then pkill python; fi
+ shell: bash
- name: Upload coverage
- uses: codecov/codecov-action@v1
+ uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
- file: ./coverage.xml
+ files: ./coverage.xml
cron-pip:
# pip install monai[all] and use it to run unit tests
if: github.repository == 'Project-MONAI/MONAI'
strategy:
matrix:
- container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.07"] # 21.02, 21.10 for backward comp.
+ container: ["pytorch:21.02", "pytorch:21.10", "pytorch:22.10"] # 21.02, 21.10 for backward comp.
container:
image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image
options: "--gpus all"
- runs-on: [self-hosted, linux, x64, common]
+ runs-on: [self-hosted, linux, x64, integration]
steps:
- uses: actions/checkout@v3
with:
@@ -121,6 +179,7 @@ jobs:
python -m pip install --upgrade pip wheel twine
python -m pip list
- name: Run tests report coverage
+ shell: bash
run: |
pip uninstall monai
pip list | grep -iv monai
@@ -160,7 +219,7 @@ jobs:
echo "Sleep $LAUNCH_DELAY"
sleep $LAUNCH_DELAY
nvidia-smi
- export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)
echo $CUDA_VISIBLE_DEVICES
trap 'if pgrep python; then pkill python; fi;' ERR
python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
@@ -175,7 +234,7 @@ jobs:
container:
image: docker://projectmonai/monai:latest # this might be slow and has the pull count limitations
options: "--gpus all"
- runs-on: [self-hosted, linux, x64, common]
+ runs-on: [self-hosted, linux, x64, integration]
steps:
- name: Run tests report coverage
# The docker image process has done the compilation.
@@ -183,7 +242,7 @@ jobs:
run: |
cd /opt/monai
nvidia-smi
- export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)
echo $CUDA_VISIBLE_DEVICES
trap 'if pgrep python; then pkill python; fi;' ERR
python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
@@ -192,21 +251,22 @@ jobs:
ngc --version
BUILD_MONAI=1 ./runtests.sh --build --coverage --pytype --unittests --disttests # unit tests with pytype checks, coverage report
BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report
- coverage xml
+ coverage xml --ignore-errors
if pgrep python; then pkill python; fi
+ shell: bash
- name: Upload coverage
- uses: codecov/codecov-action@v1
+ uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
- file: ./coverage.xml
+ files: ./coverage.xml
cron-tutorial-notebooks:
if: github.repository == 'Project-MONAI/MONAI'
needs: cron-gpu # so that monai itself is verified first
container:
- image: nvcr.io/nvidia/pytorch:22.07-py3 # testing with the latest pytorch base image
+ image: nvcr.io/nvidia/pytorch:22.10-py3 # testing with the latest pytorch base image
options: "--gpus all --ipc=host"
- runs-on: [self-hosted, linux, x64, common]
+ runs-on: [self-hosted, linux, x64, integration]
steps:
- uses: actions/checkout@v3
- name: Install MONAI
@@ -217,9 +277,9 @@ jobs:
python -m pip install -r requirements-dev.txt
BUILD_MONAI=1 python setup.py develop # install monai
nvidia-smi
- export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)
echo $CUDA_VISIBLE_DEVICES
- echo "::set-output name=devices::$CUDA_VISIBLE_DEVICES"
+ echo "devices=$CUDA_VISIBLE_DEVICES" >> $GITHUB_OUTPUT
- name: Checkout tutorials and install their requirements
run: |
cd /opt
@@ -238,3 +298,4 @@ jobs:
$(pwd)/runner.sh
python -c 'import monai; monai.config.print_debug_info()'
if pgrep python; then pkill python; fi
+ shell: bash
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 4ca1b182618..b88923e43d4 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -1,3 +1,4 @@
+# this is the docker image releasing pipeline, pushing to https://hub.docker.com/r/projectmonai/monai
name: docker
# versioning: compute a static version file
# local_docker: use the version file to build docker images
@@ -25,9 +26,9 @@ jobs:
ref: dev
fetch-depth: 0
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- shell: bash
run: |
git describe
@@ -91,7 +92,7 @@ jobs:
steps:
- name: Import
run: |
- export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)
echo $CUDA_VISIBLE_DEVICES
python -c 'import monai; monai.config.print_debug_info()'
cd /opt/monai
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index 951b09d082b..9000268d0a6 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -1,3 +1,4 @@
+# manually trigger integration with the latest pytorch
name: integration
on:
@@ -7,9 +8,9 @@ on:
jobs:
integration-py3:
container:
- image: nvcr.io/nvidia/pytorch:22.04-py3 # CUDA 11.6
+ image: nvcr.io/nvidia/pytorch:22.04-py3 # CUDA 11.6 py38
options: --gpus all # shm-size 4g works fine
- runs-on: [self-hosted, linux, x64, common]
+ runs-on: [self-hosted, linux, x64, integration]
steps:
# checkout the pull request branch
- uses: actions/checkout@v3
@@ -20,7 +21,7 @@ jobs:
- name: cache weekly timestamp
id: pip-cache
run: |
- echo "::set-output name=datew::$(date '+%Y-%V')"
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
uses: actions/cache@v3
id: cache
@@ -34,14 +35,14 @@ jobs:
which python
python -m pip install --upgrade pip wheel
python -m pip uninstall -y torch torchvision
- python -m pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 -f https://download.pytorch.org/whl/torch_stable.html
+ python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
python -m pip install -r requirements-dev.txt
rm -rf /github/home/.cache/torch/hub/mmars/
- name: Run integration tests
run: |
python -m pip list
nvidia-smi
- export CUDA_VISIBLE_DEVICES=$(python -m tests.utils)
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)
echo $CUDA_VISIBLE_DEVICES
trap 'if pgrep python; then pkill python; fi;' ERR
python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
@@ -53,6 +54,7 @@ jobs:
shell: bash
- name: Add reaction
uses: peter-evans/create-or-update-comment@v1
+ if: github.event.pull_request.number != ''
with:
token: ${{ secrets.PR_MAINTAIN }}
repository: ${{ github.event.client_payload.github.payload.repository.full_name }}
diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml
index 8a78a49f0c4..9541bd1caa9 100644
--- a/.github/workflows/pythonapp-gpu.yml
+++ b/.github/workflows/pythonapp-gpu.yml
@@ -1,4 +1,5 @@
-name: build-gpu
+# Jenkinsfile.monai-premerge
+name: premerge-gpu
on:
# quick tests for pull requests and the releasing branches
@@ -7,6 +8,7 @@ on:
- main
- releasing/*
pull_request:
+ types: [opened, synchronize, closed]
concurrency:
# automatically cancel the previously triggered workflows when there's a newer version
@@ -15,66 +17,54 @@ concurrency:
jobs:
GPU-quick-py3: # GPU with full dependencies
- if: github.repository == 'Project-MONAI/MONAI'
+ if: ${{ github.repository == 'Project-MONAI/MONAI' && github.event.pull_request.merged != true }}
strategy:
matrix:
environment:
- - "PT19+CUDA114"
- - "PT17+CUDA102"
- "PT18+CUDA102"
- - "PT18+CUDA112"
- - "PT112+CUDA117"
- - "PT110+CUDA102"
- - "PT112+CUDA102"
+ - "PT19+CUDA114DOCKER"
+ - "PT110+CUDA111"
+ - "PT112+CUDA118DOCKER"
+ - "PT113+CUDA116"
include:
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes
- - environment: PT17+CUDA102
- pytorch: "torch==1.7.1 torchvision==0.8.2"
- base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
- environment: PT18+CUDA102
# pytorch 1.8.2 LTS
pytorch: "torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu102"
base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
- - environment: PT18+CUDA112
- # we explicitly set pytorch to -h to avoid pip install error
- # 21.03: 1.9.0a0+df837d0
- pytorch: "-h"
- base: "nvcr.io/nvidia/pytorch:21.03-py3"
- - environment: PT19+CUDA114
- # we explicitly set pytorch to -h to avoid pip install error
+ - environment: PT19+CUDA114DOCKER
# 21.10: 1.10.0a0+0aef44c
- pytorch: "-h"
+ pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error
base: "nvcr.io/nvidia/pytorch:21.10-py3"
- - environment: PT112+CUDA117
- # we explicitly set pytorch to -h to avoid pip install error
- # 22.07: 1.13.0a0+08820cb
- pytorch: "-h"
- base: "nvcr.io/nvidia/pytorch:22.07-py3"
- - environment: PT110+CUDA102
- pytorch: "torch==1.10.2 torchvision==0.11.3"
- base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
- - environment: PT112+CUDA102
- pytorch: "torch==1.12.0 torchvision==0.13.0"
- base: "nvcr.io/nvidia/cuda:10.2-devel-ubuntu18.04"
+ - environment: PT110+CUDA111
+ pytorch: "torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu111"
+ base: "nvcr.io/nvidia/cuda:11.1.1-devel-ubuntu18.04"
+ - environment: PT112+CUDA118DOCKER
+ # 22.09: 1.13.0a0+d0d6b1f
+ pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error
+ base: "nvcr.io/nvidia/pytorch:22.09-py3"
+ - environment: PT113+CUDA116
+ pytorch: "torch==1.13.0 torchvision==0.14.0"
+ base: "nvcr.io/nvidia/cuda:11.6.1-devel-ubuntu18.04"
container:
image: ${{ matrix.base }}
- options: --gpus all
+ options: --gpus all --env NVIDIA_DISABLE_REQUIRE=true # workaround for unsatisfied condition: cuda>=11.6
runs-on: [self-hosted, linux, x64, common]
steps:
- uses: actions/checkout@v3
- name: apt install
+ if: github.event.pull_request.merged != true
run: |
- # workaround for https://github.com/Project-MONAI/MONAI/issues/4200
+ # FIXME: workaround for https://github.com/Project-MONAI/MONAI/issues/4200
apt-key del 7fa2af80 && rm -rf /etc/apt/sources.list.d/nvidia-ml.list /etc/apt/sources.list.d/cuda.list
apt-get update
apt-get install -y wget
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb
dpkg -i cuda-keyring_1.0-1_all.deb
- if [ ${{ matrix.environment }} = "PT17+CUDA102" ] || \
- [ ${{ matrix.environment }} = "PT18+CUDA102" ] || \
- [ ${{ matrix.environment }} = "PT110+CUDA102" ] || \
- [ ${{ matrix.environment }} = "PT112+CUDA102" ]
+ if [ ${{ matrix.environment }} = "PT18+CUDA102" ] || \
+ [ ${{ matrix.environment }} = "PT110+CUDA111" ] || \
+ [ ${{ matrix.environment }} = "PT113+CUDA116" ]
then
PYVER=3.7 PYSFX=3 DISTUTILS=python3-distutils && \
apt-get update && apt-get install -y --no-install-recommends \
@@ -108,10 +98,11 @@ jobs:
python get-pip.py && \
rm get-pip.py;
fi
- - if: matrix.environment == 'PT19+CUDA114'
+ - if: matrix.environment == 'PT19+CUDA114DOCKER'
name: Optional Cupy dependency (cuda114)
run: echo "cupy-cuda114" >> requirements-dev.txt
- name: Install dependencies
+ if: github.event.pull_request.merged != true
run: |
which python
python -m pip install --upgrade pip wheel
@@ -120,7 +111,11 @@ jobs:
python -m pip install ${{ matrix.pytorch }}
python -m pip install -r requirements-dev.txt
python -m pip list
+ DRIVER_VERSION=$(cat /proc/driver/nvidia/version | head -1 | awk -F' ' '{print $8}')
+ ls -ltr /usr/lib/x86_64-linux-gnu/libcuda* && ln -fs /usr/lib/x86_64-linux-gnu/libcuda.so.$DRIVER_VERSION /usr/lib/x86_64-linux-gnu/libcuda.so.1 || true
+ ls -ltr /usr/lib64/libcuda* && ln -fs /usr/lib64/libcuda.so.$DRIVER_VERSION /usr/lib64/libcuda.so.1 || true
- name: Run quick tests (GPU)
+ if: github.event.pull_request.merged != true
run: |
git clone --depth 1 \
https://github.com/Project-MONAI/MONAI-extra-test-data.git /MONAI-extra-test-data
@@ -129,7 +124,7 @@ jobs:
export LAUNCH_DELAY=$(python -c "import numpy; print(numpy.random.randint(30) * 10)")
echo "Sleep $LAUNCH_DELAY"
sleep $LAUNCH_DELAY
- export CUDA_VISIBLE_DEVICES=$(coverage run -m tests.utils)
+ export CUDA_VISIBLE_DEVICES=$(coverage run -m tests.utils | tail -n 1)
echo $CUDA_VISIBLE_DEVICES
trap 'if pgrep python; then pkill python; fi;' ERR
python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
@@ -137,15 +132,17 @@ jobs:
python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
python -c "import monai; monai.config.print_config()"
# build for the current self-hosted CI Tesla V100
- BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST="7.0" ./runtests.sh --build --quick --unittests --disttests
- if [ ${{ matrix.environment }} = "PT110+CUDA102" ]; then
+ BUILD_MONAI=1 TORCH_CUDA_ARCH_LIST="7.0" ./runtests.sh --build --disttests
+ ./runtests.sh --quick --unittests
+ if [ ${{ matrix.environment }} = "PT18+CUDA102" ]; then
# test the clang-format tool downloading once
coverage run -m tests.clang_format_utils
fi
- coverage xml
+ coverage xml --ignore-errors
if pgrep python; then pkill python; fi
shell: bash
- name: Upload coverage
- uses: codecov/codecov-action@v1
+ if: ${{ github.head_ref != 'dev' && github.event.pull_request.merged != true }}
+ uses: codecov/codecov-action@v3
with:
- file: ./coverage.xml
+ files: ./coverage.xml
diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml
index 8a9f4dbe679..317fcbbcd2a 100644
--- a/.github/workflows/pythonapp-min.yml
+++ b/.github/workflows/pythonapp-min.yml
@@ -1,4 +1,5 @@
-name: build-min
+# Jenkinsfile.monai-premerge
+name: premerge-min
on:
# quick tests for pull requests and the releasing branches
@@ -29,9 +30,9 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- name: Prepare pip wheel
run: |
which python
@@ -39,8 +40,8 @@ jobs:
- name: cache weekly timestamp
id: pip-cache
run: |
- echo "::set-output name=datew::$(date '+%Y-%V')"
- echo "::set-output name=dir::$(pip cache dir)"
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+ echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash
- name: cache for pip
uses: actions/cache@v3
@@ -51,11 +52,11 @@ jobs:
- if: runner.os == 'windows'
name: Install torch cpu from pytorch.org (Windows only)
run: |
- python -m pip install torch==1.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ python -m pip install torch==1.13+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install the dependencies
run: |
# min. requirements
- python -m pip install torch==1.12.0
+ python -m pip install torch==1.13
python -m pip install -r requirements-min.txt
python -m pip list
BUILD_MONAI=0 python setup.py develop # no compile of extensions
@@ -74,12 +75,12 @@ jobs:
strategy:
fail-fast: false
matrix:
- python-version: [3.7, 3.8, 3.9]
+ python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
timeout-minutes: 40
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Prepare pip wheel
@@ -89,8 +90,8 @@ jobs:
- name: cache weekly timestamp
id: pip-cache
run: |
- echo "::set-output name=datew::$(date '+%Y-%V')"
- echo "::set-output name=dir::$(pip cache dir)"
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+ echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash
- name: cache for pip
uses: actions/cache@v3
@@ -101,7 +102,7 @@ jobs:
- name: Install the dependencies
run: |
# min. requirements
- python -m pip install torch==1.12.0
+ python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -r requirements-min.txt
python -m pip list
BUILD_MONAI=0 python setup.py develop # no compile of extensions
@@ -119,14 +120,14 @@ jobs:
strategy:
fail-fast: false
matrix:
- pytorch-version: [1.7.1, 1.8.2, 1.9.1, 1.10.2, 1.11.0, latest]
+ pytorch-version: ['1.8.2', '1.9.1', '1.10.2', '1.11.0', '1.12.1', 'latest']
timeout-minutes: 40
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- name: Prepare pip wheel
run: |
which python
@@ -134,8 +135,8 @@ jobs:
- name: cache weekly timestamp
id: pip-cache
run: |
- echo "::set-output name=datew::$(date '+%Y-%V')"
- echo "::set-output name=dir::$(pip cache dir)"
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+ echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash
- name: cache for pip
uses: actions/cache@v3
@@ -148,8 +149,6 @@ jobs:
# min. requirements
if [ ${{ matrix.pytorch-version }} == "latest" ]; then
python -m pip install torch
- elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then
- python -m pip install torch==1.7.1
elif [ ${{ matrix.pytorch-version }} == "1.8.2" ]; then
python -m pip install torch==1.8.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cpu
elif [ ${{ matrix.pytorch-version }} == "1.9.1" ]; then
@@ -158,6 +157,8 @@ jobs:
python -m pip install torch==1.10.2
elif [ ${{ matrix.pytorch-version }} == "1.11.0" ]; then
python -m pip install torch==1.11.0
+ elif [ ${{ matrix.pytorch-version }} == "1.12.1" ]; then
+ python -m pip install torch==1.12.1
fi
python -m pip install -r requirements-min.txt
python -m pip list
diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index a0b1070bd4c..f3a3fba46ee 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -1,4 +1,5 @@
-name: build
+# Jenkinsfile.monai-premerge
+name: premerge
on:
# quick tests for pull requests and the releasing branches
@@ -24,13 +25,13 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- name: cache weekly timestamp
id: pip-cache
run: |
- echo "::set-output name=datew::$(date '+%Y-%V')"
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
uses: actions/cache@v3
id: cache
@@ -54,7 +55,7 @@ jobs:
fail-fast: false
matrix:
os: [windows-latest, macOS-latest, ubuntu-latest]
- timeout-minutes: 60
+ timeout-minutes: 120
steps:
- if: runner.os == 'windows'
name: Config pagefile (Windows only)
@@ -65,9 +66,9 @@ jobs:
disk-root: "D:"
- uses: actions/checkout@v3
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- name: Prepare pip wheel
run: |
which python
@@ -75,8 +76,8 @@ jobs:
- name: cache weekly timestamp
id: pip-cache
run: |
- echo "::set-output name=datew::$(date '+%Y-%V')"
- echo "::set-output name=dir::$(pip cache dir)"
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+ echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash
- name: cache for pip
uses: actions/cache@v3
@@ -87,10 +88,14 @@ jobs:
- if: runner.os == 'windows'
name: Install torch cpu from pytorch.org (Windows only)
run: |
- python -m pip install torch==1.12.0+cpu torchvision==0.13.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ python -m pip install torch==1.13.0+cpu torchvision==0.14.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+ - if: runner.os == 'Linux'
+ name: Install itk pre-release (Linux only)
+ run: |
+ python -m pip install --pre -U itk
- name: Install the dependencies
run: |
- python -m pip install torch==1.12.0 torchvision==0.13.0
+ python -m pip install torch==1.13.0 torchvision==0.14.0
cat "requirements-dev.txt"
python -m pip install -r requirements-dev.txt
python -m pip list
@@ -117,13 +122,13 @@ jobs:
with:
fetch-depth: 0
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- name: cache weekly timestamp
id: pip-cache
run: |
- echo "::set-output name=datew::$(date '+%Y-%V')"
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
uses: actions/cache@v3
id: cache
@@ -138,7 +143,7 @@ jobs:
# install the latest pytorch for testing
# however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated
# fresh torch installation according to pyproject.toml
- python -m pip install torch>=1.7 torchvision
+ python -m pip install torch>=1.8 torchvision
- name: Check packages
run: |
pip uninstall monai
@@ -150,9 +155,9 @@ jobs:
python setup.py check -m -s
python setup.py sdist bdist_wheel
python -m twine check dist/*
- - run: echo "::set-output name=pwd::$PWD"
+ - run: echo "pwd=$PWD" >> $GITHUB_OUTPUT
id: root
- - run: echo "::set-output name=tmp_dir::$(mktemp -d)"
+ - run: echo "tmp_dir=$(mktemp -d)" >> $GITHUB_OUTPUT
id: mktemp
- name: Move packages
run: |
@@ -198,13 +203,13 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- name: cache weekly timestamp
id: pip-cache
run: |
- echo "::set-output name=datew::$(date '+%Y-%V')"
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
uses: actions/cache@v3
id: cache
@@ -222,5 +227,6 @@ jobs:
cd docs/
make clean
make html 2>&1 | tee tmp_log
+ if [[ $(grep -c "ERROR:" tmp_log) != 0 ]]; then echo "found errors"; grep "ERROR:" tmp_log; exit 1; fi
if [[ $(grep -c "WARNING:" tmp_log) != 0 ]]; then echo "found warnings"; grep "WARNING:" tmp_log; exit 1; fi
shell: bash
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index c79fb0a496f..b63fc40a4d5 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -13,13 +13,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [3.7, 3.8, 3.9]
+ python-version: ['3.7', '3.8', '3.9', '3.10']
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install setuptools
@@ -96,9 +96,9 @@ jobs:
with:
fetch-depth: 0
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- shell: bash
run: |
git describe
@@ -127,7 +127,7 @@ jobs:
name: _version.py
- name: Set tag
id: versioning
- run: echo ::set-output name=tag::${GITHUB_REF#refs/*/}
+ run: echo "tag=${GITHUB_REF#refs/*/}" >> $GITHUB_OUTPUT
- name: Check tag
env:
RELEASE_VERSION: ${{ steps.versioning.outputs.tag }}
diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml
index 7cc4aec8b8b..b87e8e6b355 100644
--- a/.github/workflows/setupapp.yml
+++ b/.github/workflows/setupapp.yml
@@ -1,7 +1,9 @@
+# Jenkinsfile.monai-postmerge
name: deploy
on:
# full tests for all the important branches
+<<<<<<< HEAD
# push:
# branches:
# - main
@@ -10,6 +12,14 @@ on:
schedule:
- cron: "00 * * * *" # trigger per hour
workflow_dispatch:
+=======
+ push:
+ branches:
+ - main
+ - releasing/*
+ - feature/*
+ - dev
+>>>>>>> upstream/dev
concurrency:
# automatically cancel the previously triggered workflows when there's a newer version
@@ -19,8 +29,9 @@ concurrency:
jobs:
# caching of these jobs:
# - docker-py3-pip- (shared)
- # - ubuntu py36 37 38-pip-
+ # - ubuntu 37 38 39 310-pip-
# - os-latest-pip (shared)
+<<<<<<< HEAD
# coverage-py3:
# if: github.repository == 'Project-MONAI/MONAI'
# container:
@@ -72,24 +83,79 @@ jobs:
# with:
# fail_ci_if_error: false
# file: ./coverage.xml
+=======
+ coverage-py3:
+ if: github.repository == 'Project-MONAI/MONAI'
+ container:
+ image: nvcr.io/nvidia/pytorch:22.08-py3 # CUDA 11.7
+ options: --gpus all
+ runs-on: [self-hosted, linux, x64, integration]
+ steps:
+ - uses: actions/checkout@v3
+ - name: cache weekly timestamp
+ id: pip-cache
+ run: |
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+ - name: cache for pip
+ if: ${{ startsWith(github.ref, 'refs/heads/dev') }}
+ uses: actions/cache@v3
+ id: cache
+ with:
+ path: |
+ ~/.cache/pip
+ ~/.cache/torch
+ key: docker-py3-pip-${{ steps.pip-cache.outputs.datew }}
+ - name: Install the dependencies
+ run: |
+ which python
+ python -m pip install --upgrade pip wheel
+ python -m pip uninstall -y torch torchvision
+ rm -rf $(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")/ruamel*
+ python -m pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 -f https://download.pytorch.org/whl/torch_stable.html
+ python -m pip install -r requirements-dev.txt
+ - name: Run unit tests report coverage
+ run: |
+ python -m pip list
+ git config --global --add safe.directory /__w/MONAI/MONAI
+ git clean -ffdx
+ df -h
+ # python -m pip cache info
+ nvidia-smi
+ export CUDA_VISIBLE_DEVICES=$(python -m tests.utils | tail -n 1)
+ echo $CUDA_VISIBLE_DEVICES
+ trap 'if pgrep python; then pkill python; fi;' ERR
+ python -c $'import torch\na,b=torch.zeros(1,device="cuda:0"),torch.zeros(1,device="cuda:1");\nwhile True:print(a,b)' > /dev/null &
+ python -c "import torch; print(torch.__version__); print('{} of GPUs available'.format(torch.cuda.device_count()))"
+ python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
+ BUILD_MONAI=1 ./runtests.sh --build --coverage --unittests --disttests # unit tests with coverage report
+ BUILD_MONAI=1 ./runtests.sh --build --coverage --net # integration tests with coverage report
+ coverage xml --ignore-errors
+ if pgrep python; then pkill python; fi
+ shell: bash
+ - name: Upload coverage
+ uses: codecov/codecov-action@v3
+ with:
+ fail_ci_if_error: false
+ files: ./coverage.xml
+>>>>>>> upstream/dev
test-py3x:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [3.7, 3.8, 3.9]
+ python-version: ['3.7', '3.8', '3.9', '3.10']
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: cache weekly timestamp
id: pip-cache
run: |
- echo "::set-output name=datew::$(date '+%Y-%V')"
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
- name: cache for pip
uses: actions/cache@v3
id: cache
@@ -105,22 +171,33 @@ jobs:
pwd
ls
python -m pip install --upgrade pip wheel
+<<<<<<< HEAD
python -m pip install torch==1.12.0 torchvision==0.13.0
python -m pip install -r requirements-min.txt
+=======
+ python -m pip install torch==1.13.0 torchvision==0.14.0
+ python -m pip install -r requirements-dev.txt
+>>>>>>> upstream/dev
- name: Run quick tests CPU ubuntu
run: |
python -m pip list
BUILD_MONAI=0 python setup.py develop
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
+<<<<<<< HEAD
python -c "import monai; monai.config.print_debug_info()"
QUICKTEST=True ./runtests.sh --build --min
coverage xml
+=======
+ BUILD_MONAI=1 ./runtests.sh --build --quick --unittests --disttests
+ coverage xml --ignore-errors
+>>>>>>> upstream/dev
- name: Upload coverage
- uses: codecov/codecov-action@v1
+ uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
- file: ./coverage.xml
+ files: ./coverage.xml
+<<<<<<< HEAD
# install: # pip install from github url, the default branch is dev
# runs-on: ubuntu-latest
# steps:
@@ -168,3 +245,52 @@ jobs:
# python -m tests.min_tests
# env:
# QUICKTEST: True
+=======
+ install: # pip install from github url, the default branch is dev
+ runs-on: ubuntu-latest
+ steps:
+ - name: Set up Python 3.8
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.8'
+ - name: cache weekly timestamp
+ id: pip-cache
+ run: |
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+ - name: cache for pip
+ uses: actions/cache@v3
+ id: cache
+ with:
+ path: |
+ ~/.cache/pip
+ ~/.cache/torch
+ key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
+ - name: Install the default branch no build (dev branch only)
+ if: github.ref == 'refs/heads/dev'
+ run: |
+ BUILD_MONAI=0 pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
+ python -c 'import monai; monai.config.print_config()'
+ cd $(python -c 'import monai; import os; print(os.path.dirname(monai.__file__))')
+ ls .
+ pip uninstall -y monai
+ - name: Install the default branch with build (dev branch only)
+ if: github.ref == 'refs/heads/dev'
+ run: |
+ BUILD_MONAI=1 pip install git+https://github.com/Project-MONAI/MONAI#egg=MONAI
+ python -c 'import monai; monai.config.print_config()'
+ - name: Get the test cases (dev branch only)
+ if: github.ref == 'refs/heads/dev'
+ uses: actions/checkout@v3
+ with:
+ ref: dev
+ - name: Quick test installed (dev branch only)
+ if: github.ref == 'refs/heads/dev'
+ run: |
+ cd $GITHUB_WORKSPACE
+ rm -rf monai/
+ ls -al .
+ python -m pip install -r requirements-min.txt
+ python -m tests.min_tests
+ env:
+ QUICKTEST: True
+>>>>>>> upstream/dev
diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml
index 8fa69615c39..238302a0186 100644
--- a/.github/workflows/weekly-preview.yml
+++ b/.github/workflows/weekly-preview.yml
@@ -14,9 +14,9 @@ jobs:
ref: dev
fetch-depth: 0
- name: Set up Python 3.8
- uses: actions/setup-python@v3
+ uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: '3.8'
- name: Install setuptools
run: |
python -m pip install --user --upgrade setuptools wheel
@@ -33,7 +33,7 @@ jobs:
export YEAR_WEEK=$(date +'%y%U')
echo "Year week for tag is ${YEAR_WEEK}"
if ! [[ $YEAR_WEEK =~ ^[0-9]{4}$ ]] ; then echo "Wrong 'year week' format. Should be 4 digits."; exit 1 ; fi
- git tag "0.10.dev${YEAR_WEEK}"
+ git tag "1.1.dev${YEAR_WEEK}"
git log -1
git tag --list
python setup.py sdist bdist_wheel
diff --git a/.gitignore b/.gitignore
index 9c0554dca93..3da001d0ce4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -131,7 +131,10 @@ tests/testing_data/MedNIST*
tests/testing_data/*Hippocampus*
tests/testing_data/*.tiff
tests/testing_data/schema.json
-*.svg
+tests/testing_data/endo.mp4
+tests/testing_data/ultrasound.avi
+tests/testing_data/train_data_stats.yaml
+tests/testing_data/eval_data_stats.yaml
# clang format tool
.clang-format-bin/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d2a69908306..62550f51d4b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -28,7 +28,7 @@ repos:
- id: mixed-line-ending
- repo: https://github.com/asottile/pyupgrade
- rev: v2.34.0
+ rev: v2.38.2
hooks:
- id: pyupgrade
args: [--py37-plus]
@@ -40,7 +40,7 @@ repos:
)$
- repo: https://github.com/asottile/yesqa
- rev: v1.3.0
+ rev: v1.4.0
hooks:
- id: yesqa
name: Unused noqa
@@ -58,7 +58,7 @@ repos:
)$
- repo: https://github.com/hadialqattan/pycln
- rev: v1.3.5
+ rev: v2.1.1
hooks:
- id: pycln
args: [--config=pyproject.toml]
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4a8dcfc60d7..abc8f703d72 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,11 +1,92 @@
# Changelog
All notable changes to MONAI are documented in this file.
-The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
-and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
+The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased]
+## [1.0.1] - 2022-10-24
+### Fixes
+* DiceCELoss for multichannel targets
+* Auto3DSeg DataAnalyzer out-of-memory error and other minor issues
+* An optional flag issue in the RetinaNet detector
+* An issue with output offset for Spacing
+* A `LoadImage` issue when `track_meta` is `False`
+* 1D data output error in `VarAutoEncoder`
+* An issue with resolution computing in `ImageStats`
+### Added
+* Flexible min/max pixdim options for Spacing
+* Upsample mode `deconvgroup` and optional kernel sizes
+* Docstrings for gradient-based saliency maps
+* Occlusion sensitivity to use sliding window inference
+* Enhanced Gaussian window and device assignments for sliding window inference
+* Multi-GPU support for MonaiAlgo
+* `ClientAlgoStats` and `MonaiAlgoStats` for federated summary statistics
+* MetaTensor support for `OneOf`
+* Add a file check for bundle logging config
+* Additional content and an authentication token option for bundle info API
+* An anti-aliasing option for `Resized`
+* `SlidingWindowInferer` adaptive device based on `cpu_thresh`
+* `SegResNetDS` with deep supervision and non-isotropic kernel support
+* Premerge tests for Python 3.10
+### Changed
+* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:22.09-py3` from `nvcr.io/nvidia/pytorch:22.08-py3`
+* Replace `None` type metadata content with `"none"` for `collate_fn` compatibility
+* HoVerNet Mode and Branch to independent StrEnum
+* Automatically infer device from the first item in random elastic deformation dict
+* Add channel dim in `ComputeHoVerMaps` and `ComputeHoVerMapsd`
+* Remove batch dim in `SobelGradients` and `SobelGradientsd`
+### Deprecated
+* Deprecating `compute_meandice`, `compute_meaniou` in `monai.metrics`, in favor of
+`compute_dice` and `compute_iou` respectively
+
+## [1.0.0] - 2022-09-16
+### Added
+* `monai.auto3dseg` base APIs and `monai.apps.auto3dseg` components for automated machine learning (AutoML) workflow
+* `monai.fl` module with base APIs and `MonaiAlgo` for federated learning client workflow
+* An initial backwards compatibility [guide](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md#backwards-compatibility)
+* Initial release of accelerated MRI reconstruction components, including `CoilSensitivityModel`
+* Support of `MetaTensor` and new metadata attributes for various digital pathology components
+* Various `monai.bundle` enhancements for MONAI model-zoo usability, including config debug mode and `get_all_bundles_list`
+* new `monai.transforms` components including `SignalContinuousWavelet` for 1D signal, `ComputeHoVerMaps` for digital pathology, and `SobelGradients` for spatial gradients
+* `VarianceMetric` and `LabelQualityScore` metrics for active learning
+* Dataset API for real-time stream and videos
+* Several networks and building blocks including `FlexibleUNet` and `HoVerNet`
+* `MeanIoUHandler` and `LogfileHandler` workflow event handlers
+* `WSIReader` with the TiffFile backend
+* Multi-threading in `WSIReader` with cuCIM backend
+* `get_stats` API in `monai.engines.Workflow`
+* `prune_meta_pattern` in `monai.transforms.LoadImage`
+* `max_interactions` for deepedit interaction workflow
+* Various profiling utilities in `monai.utils.profiling`
+### Changed
+* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:22.08-py3` from `nvcr.io/nvidia/pytorch:22.06-py3`
+* Optionally depend on PyTorch-Ignite v0.4.10 instead of v0.4.9
+* The cache-based dataset now matches the transform information when read/write the cache
+* `monai.losses.ContrastiveLoss` now infers `batch_size` during `forward()`
+* Rearrange the spatial axes in `RandSmoothDeform` transforms following PyTorch's convention
+* Unified several environment flags into `monai.utils.misc.MONAIEnvVars`
+* Simplified `__str__` implementation of `MetaTensor` instead of relying on the `__repr__` implementation
+### Fixed
+* Improved error messages when both `monai` and `monai-weekly` are pip-installed
+* Inconsistent pseudo number sequences for different `num_workers` in `DataLoader`
+* Issue of repeated sequences for `monai.data.ShuffleBuffer`
+* Issue of not preserving the physical extent in `monai.transforms.Spacing`
+* Issue of using `inception_v3` as the backbone of `monai.networks.nets.TorchVisionFCModel`
+* Index device issue for `monai.transforms.Crop`
+* Efficiency issue when converting the array dtype and contiguous memory
+### Deprecated
+* `Addchannel` and `AsChannelFirst` transforms in favor of `EnsureChannelFirst`
+* `monai.apps.pathology.data` components in favor of the corresponding components from `monai.data`
+* `monai.apps.pathology.handlers` in favor of the corresponding components from `monai.handlers`
+### Removed
+* `Status` section in the pull request template in favor of the pull request draft mode
+* `monai.engines.BaseWorkflow`
+* `ndim` and `dimensions` arguments in favor of `spatial_dims`
+* `n_classes`, `num_classes` arguments in `AsDiscrete` in favor of `to_onehot`
+* `logit_thresh`, `threshold_values` arguments in `AsDiscrete` in favor of `threshold`
+* `torch.testing.assert_allclose` in favor of `tests.utils.assert_allclose`
+
## [0.9.1] - 2022-07-22
### Added
* Support of `monai.data.MetaTensor` as core data structure across the modules
@@ -562,7 +643,9 @@ the postprocessing steps should be used before calling the metrics methods
[highlights]: https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md
-[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/0.9.1...HEAD
+[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/1.0.1...HEAD
+[1.0.1]: https://github.com/Project-MONAI/MONAI/compare/1.0.0...1.0.1
+[1.0.0]: https://github.com/Project-MONAI/MONAI/compare/0.9.1...1.0.0
[0.9.1]: https://github.com/Project-MONAI/MONAI/compare/0.9.0...0.9.1
[0.9.0]: https://github.com/Project-MONAI/MONAI/compare/0.8.1...0.9.0
[0.8.1]: https://github.com/Project-MONAI/MONAI/compare/0.8.0...0.8.1
diff --git a/CITATION.cff b/CITATION.cff
index a2ac1845f80..9b8a9f6ec51 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -6,8 +6,8 @@ title: "MONAI: Medical Open Network for AI"
abstract: "AI Toolkit for Healthcare Imaging"
authors:
- name: "MONAI Consortium"
-date-released: 2022-07-25
-version: "0.9.1"
+date-released: 2022-09-16
+version: "1.0.0"
identifiers:
- description: "This DOI represents all versions of MONAI, and will always resolve to the latest one."
type: doi
@@ -17,4 +17,125 @@ repository-code: "https://github.com/Project-MONAI/MONAI"
url: "https://monai.io"
cff-version: "1.2.0"
message: "If you use this software, please cite it using these metadata."
+preferred-citation:
+ type: article
+ authors:
+ - given-names: "M. Jorge"
+ family-names: "Cardoso"
+ - given-names: "Wenqi"
+ family-names: "Li"
+ - given-names: "Richard"
+ family-names: "Brown"
+ - given-names: "Nic"
+ family-names: "Ma"
+ - given-names: "Eric"
+ family-names: "Kerfoot"
+ - given-names: "Yiheng"
+ family-names: "Wang"
+ - given-names: "Benjamin"
+ family-names: "Murray"
+ - given-names: "Andriy"
+ family-names: "Myronenko"
+ - given-names: "Can"
+ family-names: "Zhao"
+ - given-names: "Dong"
+ family-names: "Yang"
+ - given-names: "Vishwesh"
+ family-names: "Nath"
+ - given-names: "Yufan"
+ family-names: "He"
+ - given-names: "Ziyue"
+ family-names: "Xu"
+ - given-names: "Ali"
+ family-names: "Hatamizadeh"
+ - given-names: "Andriy"
+ family-names: "Myronenko"
+ - given-names: "Wentao"
+ family-names: "Zhu"
+ - given-names: "Yun"
+ family-names: "Liu"
+ - given-names: "Mingxin"
+ family-names: "Zheng"
+ - given-names: "Yucheng"
+ family-names: "Tang"
+ - given-names: "Isaac"
+ family-names: "Yang"
+ - given-names: "Michael"
+ family-names: "Zephyr"
+ - given-names: "Behrooz"
+ family-names: "Hashemian"
+ - given-names: "Sachidanand"
+ family-names: "Alle"
+ - given-names: "Mohammad"
+ family-names: "Zalbagi Darestani"
+ - given-names: "Charlie"
+ family-names: "Budd"
+ - given-names: "Marc"
+ family-names: "Modat"
+ - given-names: "Tom"
+ family-names: "Vercauteren"
+ - given-names: "Guotai"
+ family-names: "Wang"
+ - given-names: "Yiwen"
+ family-names: "Li"
+ - given-names: "Yipeng"
+ family-names: "Hu"
+ - given-names: "Yunguan"
+ family-names: "Fu"
+ - given-names: "Benjamin"
+ family-names: "Gorman"
+ - given-names: "Hans"
+ family-names: "Johnson"
+ - given-names: "Brad"
+ family-names: "Genereaux"
+ - given-names: "Barbaros S."
+ family-names: "Erdal"
+ - given-names: "Vikash"
+ family-names: "Gupta"
+ - given-names: "Andres"
+ family-names: "Diaz-Pinto"
+ - given-names: "Andre"
+ family-names: "Dourson"
+ - given-names: "Lena"
+ family-names: "Maier-Hein"
+ - given-names: "Paul F."
+ family-names: "Jaeger"
+ - given-names: "Michael"
+ family-names: "Baumgartner"
+ - given-names: "Jayashree"
+ family-names: "Kalpathy-Cramer"
+ - given-names: "Mona"
+ family-names: "Flores"
+ - given-names: "Justin"
+ family-names: "Kirby"
+ - given-names: "Lee A.D."
+ family-names: "Cooper"
+ - given-names: "Holger R."
+ family-names: "Roth"
+ - given-names: "Daguang"
+ family-names: "Xu"
+ - given-names: "David"
+ family-names: "Bericat"
+ - given-names: "Ralf"
+ family-names: "Floca"
+ - given-names: "S. Kevin"
+ family-names: "Zhou"
+ - given-names: "Haris"
+ family-names: "Shuaib"
+ - given-names: "Keyvan"
+ family-names: "Farahani"
+ - given-names: "Klaus H."
+ family-names: "Maier-Hein"
+ - given-names: "Stephen"
+ family-names: "Aylward"
+ - given-names: "Prerna"
+ family-names: "Dogra"
+ - given-names: "Sebastien"
+ family-names: "Ourselin"
+ - given-names: "Andrew"
+ family-names: "Feng"
+ doi: "https://doi.org/10.48550/arXiv.2211.02701"
+ month: 11
+ year: 2022
+ title: "MONAI: An open-source framework for deep learning in healthcare"
...
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index ac373bad75f..07290936c5d 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -32,7 +32,7 @@ MONAI is part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/), and mainly
_Pull request early_
-We encourage you to create pull requests early. It helps us track the contributions under development, whether they are ready to be merged or not. Change your pull request's title, to begin with `[WIP]` and/or [create a draft pull request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests#draft-pull-requests) until it is ready for formal review.
+We encourage you to create pull requests early. It helps us track the contributions under development, whether they are ready to be merged or not. [Create a draft pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/changing-the-stage-of-a-pull-request) until it is ready for formal review.
Please note that, as per PyTorch, MONAI uses American English spelling. This means classes and variables should be: normali**z**e, visuali**z**e, colo~~u~~r, etc.
@@ -193,7 +193,8 @@ Integration tests with minimal requirements are deployed to ensure this strategy
To add new optional dependencies, please communicate with the core team during pull request reviews,
and add the necessary information (at least) to the following files:
- [setup.cfg](https://github.com/Project-MONAI/MONAI/blob/dev/setup.cfg) (for package's `[options.extras_require]` config)
-- [docs/requirements.txt](https://github.com/Project-MONAI/MONAI/blob/dev/docs/requirements.txt) (pip requirements.txt file)
+- [requirements-dev.txt](https://github.com/Project-MONAI/MONAI/blob/dev/requirements-dev.txt) (pip requirements file)
+- [docs/requirements.txt](https://github.com/Project-MONAI/MONAI/blob/dev/docs/requirements.txt) (docs pip requirements file)
- [environment-dev.yml](https://github.com/Project-MONAI/MONAI/blob/dev/environment-dev.yml) (conda environment file)
- [installation.md](https://github.com/Project-MONAI/MONAI/blob/dev/docs/source/installation.md) (documentation)
@@ -283,10 +284,25 @@ for example, ``import monai.transforms.Spacing`` is the equivalent of ``monai.tr
For string definition, [f-string](https://www.python.org/dev/peps/pep-0498/) is recommended to use over `%-print` and `format-print`. So please try to use `f-string` if you need to define any string object.
#### Backwards compatibility
-MONAI is currently under active development, and with major version zero (following the [Semantic Versioning](https://semver.org/)).
-The backwards compatibility of the API is not always guaranteed at this initial development stage.
-However, utility functions are provided in the `monai.utils.deprecated` modules to help users migrate to the new API.
-The use of these functions is encouraged.
+MONAI in general follows [PyTorch's policy for backward compatibility](https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy).
+Utility functions are provided in `monai.utils.deprecated` to help migrate from the deprecated to new APIs. The use of these utilities is encouraged.
+The pull request [template contains checkboxes](https://github.com/Project-MONAI/MONAI/blame/dev/.github/pull_request_template.md#L11-L12) that
+the contributor should use accordingly to clearly indicate breaking changes.
+
+The process of releasing backwards incompatible API changes is as follows:
+1. discuss the breaking changes during pull requests or in dev meetings with a feature proposal if needed.
+1. add a warning message in the upcoming release (version `X.Y`), the warning message should include a forecast of removing the deprecated API in:
+ 1. `X+1.0` -- major version `X+1` and minor version `0` the next major version if it's a significant change,
+ 1. `X.Y+2` -- major version `X` and minor version `Y+2` (the minor version after the next one), if it's a minor API change.
+ 1. Note that the versioning policy is similar to PyTorch's approach which does not precisely follow [the semantic versioning](https://semver.org/) definition.
+ Major version numbers are instead used to represent major product version (which is currently not planned to be greater than 1),
+ minor version for both compatible and incompatible, and patch version for bug fixes.
+ 1. when recommending new API to use in place of a deprecated API, the recommended version should
+ provide exact feature-like behaviour otherwise users will have a harder time migrating.
+1. add new test cases by extending the existing unit tests to cover both the deprecated and updated APIs.
+1. collect feedback from the users during the subsequent few releases, and reconsider step 1 if needed.
+1. before each release, review the deprecating APIs and relevant tests, and clean up the removed APIs described in step 2.
+
### Submitting pull requests
@@ -349,12 +365,11 @@ When major features are ready for a milestone, to prepare for a new release:
- Merge `releasing/[version number]` to `dev`, this step must make sure that the tagging commit unchanged on `dev`.
- Publish the release note.
-Note that the release should be tagged with a [PEP440](https://www.python.org/dev/peps/pep-0440/) compliant
-[semantic versioning](https://semver.org/spec/v2.0.0.html) number.
+Note that the release should be tagged with a [PEP440](https://www.python.org/dev/peps/pep-0440/) compliant version number.
If any error occurs during the release process, first check out a new hotfix branch from the `releasing/[version number]`,
then make PRs to the `releasing/[version number]` to fix the bugs via the regular contribution procedure.
If any error occurs after the release process, first check out a new hotfix branch from the `main` branch,
-make a minor version release following the semantic versioning, for example, `releasing/0.1.1`.
+make a patch version release following the semantic versioning, for example, `releasing/0.1.1`.
Make sure the `releasing/0.1.1` is merged back into both `dev` and `main` and all the test pipelines succeed.
diff --git a/Dockerfile b/Dockerfile
index ed35ac408d9..c41af38a579 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,7 +11,7 @@
# To build with a different base image
# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.07-py3
+ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:22.10-py3
FROM ${PYTORCH_IMAGE}
LABEL maintainer="monai.contact@gmail.com"
diff --git a/README.md b/README.md
index 5701d1786e6..1d48da47252 100644
--- a/README.md
+++ b/README.md
@@ -18,8 +18,7 @@ Its ambitions are:
## Features
-> _The codebase is currently under active development._
-> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the current milestone release._
+> _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._
- flexible pre-processing for multi-dimensional medical imaging data;
- compositional & portable APIs for ease of integration in existing workflows;
@@ -46,6 +45,10 @@ Examples and notebook tutorials are located at [Project-MONAI/tutorials](https:/
Technical documentation is available at [docs.monai.io](https://docs.monai.io).
+## Model Zoo
+[The MONAI Model Zoo](https://github.com/Project-MONAI/model-zoo) is a place for researchers and data scientists to share the latest and great models from the community.
+Utilizing [the MONAI Bundle format](https://docs.monai.io/en/latest/bundle_intro.html) makes it easy to [get started](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo) building workflows with MONAI.
+
## Contributing
For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md).
diff --git a/docs/Makefile b/docs/Makefile
index d9a870064a2..5afe804955b 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -20,6 +20,7 @@ help:
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
+ PIP_ROOT_USER_ACTION=ignore pip install -r requirements.txt
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
clean:
diff --git a/docs/images/3d_paired.png b/docs/images/3d_paired.png
index dd751c8e164..811ba40c854 100644
Binary files a/docs/images/3d_paired.png and b/docs/images/3d_paired.png differ
diff --git a/docs/images/MONAI-logo-color.png b/docs/images/MONAI-logo-color.png
index a62aedd70e9..d1e8b6b7be2 100644
Binary files a/docs/images/MONAI-logo-color.png and b/docs/images/MONAI-logo-color.png differ
diff --git a/docs/images/UNETR.png b/docs/images/UNETR.png
index d028f26fd66..95d2c13c9ec 100644
Binary files a/docs/images/UNETR.png and b/docs/images/UNETR.png differ
diff --git a/docs/images/affine.png b/docs/images/affine.png
index a28bc79f76b..adf77cdc613 100644
Binary files a/docs/images/affine.png and b/docs/images/affine.png differ
diff --git a/docs/images/amp_training_a100.png b/docs/images/amp_training_a100.png
index d056adfc083..9c046b5b73c 100644
Binary files a/docs/images/amp_training_a100.png and b/docs/images/amp_training_a100.png differ
diff --git a/docs/images/amp_training_v100.png b/docs/images/amp_training_v100.png
index 844f777c065..4daf1691491 100644
Binary files a/docs/images/amp_training_v100.png and b/docs/images/amp_training_v100.png differ
diff --git a/docs/images/arch_modules.png b/docs/images/arch_modules.png
new file mode 100644
index 00000000000..f754926cfcc
Binary files /dev/null and b/docs/images/arch_modules.png differ
diff --git a/docs/images/arch_modules_v0.4.png b/docs/images/arch_modules_v0.4.png
deleted file mode 100644
index ec5a7d9d217..00000000000
Binary files a/docs/images/arch_modules_v0.4.png and /dev/null differ
diff --git a/docs/images/auto3dseg.png b/docs/images/auto3dseg.png
new file mode 100644
index 00000000000..b2a9da942cd
Binary files /dev/null and b/docs/images/auto3dseg.png differ
diff --git a/docs/images/blend.png b/docs/images/blend.png
index fdf8a21385a..a67a49adb44 100644
Binary files a/docs/images/blend.png and b/docs/images/blend.png differ
diff --git a/docs/images/blend_images.png b/docs/images/blend_images.png
index 55c415cc39e..d387a7f08fe 100644
Binary files a/docs/images/blend_images.png and b/docs/images/blend_images.png differ
diff --git a/docs/images/brats_distributed.png b/docs/images/brats_distributed.png
index 90877336eef..beaa7d7b160 100644
Binary files a/docs/images/brats_distributed.png and b/docs/images/brats_distributed.png differ
diff --git a/docs/images/cache_dataset.png b/docs/images/cache_dataset.png
index 85d9badd326..8ab8ef323ed 100644
Binary files a/docs/images/cache_dataset.png and b/docs/images/cache_dataset.png differ
diff --git a/docs/images/cam.png b/docs/images/cam.png
index 3a8dcfed1de..ad5018858aa 100644
Binary files a/docs/images/cam.png and b/docs/images/cam.png differ
diff --git a/docs/images/coplenet.png b/docs/images/coplenet.png
index 6d3da3b4679..97379117942 100644
Binary files a/docs/images/coplenet.png and b/docs/images/coplenet.png differ
diff --git a/docs/images/dataset_progress.png b/docs/images/dataset_progress.png
index de86de9e728..ad941440a16 100644
Binary files a/docs/images/dataset_progress.png and b/docs/images/dataset_progress.png differ
diff --git a/docs/images/datasets_speed.png b/docs/images/datasets_speed.png
index a960d93d313..2aa35b2597b 100644
Binary files a/docs/images/datasets_speed.png and b/docs/images/datasets_speed.png differ
diff --git a/docs/images/decollate_batch.png b/docs/images/decollate_batch.png
index 2a1c0c832c3..d353483cb74 100644
Binary files a/docs/images/decollate_batch.png and b/docs/images/decollate_batch.png differ
diff --git a/docs/images/deepedit.png b/docs/images/deepedit.png
index 6da60db34e3..56c880aff23 100644
Binary files a/docs/images/deepedit.png and b/docs/images/deepedit.png differ
diff --git a/docs/images/deepgrow.png b/docs/images/deepgrow.png
index d006bd0d090..dcef67608a0 100644
Binary files a/docs/images/deepgrow.png and b/docs/images/deepgrow.png differ
diff --git a/docs/images/deepgrow_scheme.png b/docs/images/deepgrow_scheme.png
index 9b4e4008398..a719c4b5770 100644
Binary files a/docs/images/deepgrow_scheme.png and b/docs/images/deepgrow_scheme.png differ
diff --git a/docs/images/detection.png b/docs/images/detection.png
index 44637ced7fb..76fa8f585ed 100644
Binary files a/docs/images/detection.png and b/docs/images/detection.png differ
diff --git a/docs/images/dints-overview.png b/docs/images/dints-overview.png
index e5f592a8e53..96bdc0278ec 100644
Binary files a/docs/images/dints-overview.png and b/docs/images/dints-overview.png differ
diff --git a/docs/images/end_to_end.png b/docs/images/end_to_end.png
deleted file mode 100644
index e837f64a930..00000000000
Binary files a/docs/images/end_to_end.png and /dev/null differ
diff --git a/docs/images/fast_training.png b/docs/images/fast_training.png
index 8dbd1e5b8d6..7f72aef4c0c 100644
Binary files a/docs/images/fast_training.png and b/docs/images/fast_training.png differ
diff --git a/docs/images/federated.svg b/docs/images/federated.svg
new file mode 100644
index 00000000000..1c598788770
--- /dev/null
+++ b/docs/images/federated.svg
@@ -0,0 +1,245 @@
+
+
diff --git a/docs/images/gmm_feature_set_comparison_s.png b/docs/images/gmm_feature_set_comparison_s.png
index a0161b81942..c14fbe849fe 100644
Binary files a/docs/images/gmm_feature_set_comparison_s.png and b/docs/images/gmm_feature_set_comparison_s.png differ
diff --git a/docs/images/invert_transforms.png b/docs/images/invert_transforms.png
index fa3863f3733..ac001913a28 100644
Binary files a/docs/images/invert_transforms.png and b/docs/images/invert_transforms.png differ
diff --git a/docs/images/lr_finder.png b/docs/images/lr_finder.png
index ed9ba697706..3fd72b233c5 100644
Binary files a/docs/images/lr_finder.png and b/docs/images/lr_finder.png differ
diff --git a/docs/images/matshow3d.png b/docs/images/matshow3d.png
index f71e69a99f0..75cf9ef29f3 100644
Binary files a/docs/images/matshow3d.png and b/docs/images/matshow3d.png differ
diff --git a/docs/images/medical_transforms.png b/docs/images/medical_transforms.png
index 7f405863c86..0f7eb69f7b3 100644
Binary files a/docs/images/medical_transforms.png and b/docs/images/medical_transforms.png differ
diff --git a/docs/images/metrics_report.png b/docs/images/metrics_report.png
index a317fcdc21d..c72cfca08d8 100644
Binary files a/docs/images/metrics_report.png and b/docs/images/metrics_report.png differ
diff --git a/docs/images/mil-patches.jpg b/docs/images/mil-patches.jpg
index fd904943be3..668ba31567f 100644
Binary files a/docs/images/mil-patches.jpg and b/docs/images/mil-patches.jpg differ
diff --git a/docs/images/models_ensemble.png b/docs/images/models_ensemble.png
index 907bbef4141..64e69b0fceb 100644
Binary files a/docs/images/models_ensemble.png and b/docs/images/models_ensemble.png differ
diff --git a/docs/images/mri_recon.png b/docs/images/mri_recon.png
new file mode 100644
index 00000000000..bbf775aa0c0
Binary files /dev/null and b/docs/images/mri_recon.png differ
diff --git a/docs/images/multi_transform_chains.png b/docs/images/multi_transform_chains.png
deleted file mode 100644
index fe7ee26e3fe..00000000000
Binary files a/docs/images/multi_transform_chains.png and /dev/null differ
diff --git a/docs/images/nsight_comparison.png b/docs/images/nsight_comparison.png
index 9b918265134..4505ecc153d 100644
Binary files a/docs/images/nsight_comparison.png and b/docs/images/nsight_comparison.png differ
diff --git a/docs/images/pathology-meta.png b/docs/images/pathology-meta.png
new file mode 100644
index 00000000000..9305b822a47
Binary files /dev/null and b/docs/images/pathology-meta.png differ
diff --git a/docs/images/postprocessing_transforms.png b/docs/images/postprocessing_transforms.png
index 8161ca496cb..950d4a6ac00 100644
Binary files a/docs/images/postprocessing_transforms.png and b/docs/images/postprocessing_transforms.png differ
diff --git a/docs/images/rand_gaussian_noise.png b/docs/images/rand_gaussian_noise.png
deleted file mode 100644
index a824ea8cc6e..00000000000
Binary files a/docs/images/rand_gaussian_noise.png and /dev/null differ
diff --git a/docs/images/sliding_window.png b/docs/images/sliding_window.png
index 3cd3aeea10b..644678f64bc 100644
Binary files a/docs/images/sliding_window.png and b/docs/images/sliding_window.png differ
diff --git a/docs/images/ssl_overview.png b/docs/images/ssl_overview.png
index 68fa1af5768..f9352352084 100644
Binary files a/docs/images/ssl_overview.png and b/docs/images/ssl_overview.png differ
diff --git a/docs/images/swin_unetr.png b/docs/images/swin_unetr.png
index 82683a1715c..883739db6f1 100644
Binary files a/docs/images/swin_unetr.png and b/docs/images/swin_unetr.png differ
diff --git a/docs/images/threaddataloader.png b/docs/images/threaddataloader.png
index 565df8d0d4c..03db9cc8303 100644
Binary files a/docs/images/threaddataloader.png and b/docs/images/threaddataloader.png differ
diff --git a/docs/images/transfer_mmar.png b/docs/images/transfer_mmar.png
index 7ae5b876ea1..e100d013344 100644
Binary files a/docs/images/transfer_mmar.png and b/docs/images/transfer_mmar.png differ
diff --git a/docs/images/tta.png b/docs/images/tta.png
index 6c4e18ffa04..cb4906cbb37 100644
Binary files a/docs/images/tta.png and b/docs/images/tta.png differ
diff --git a/docs/images/unet-pipe.png b/docs/images/unet-pipe.png
index 0f86ae22029..946ad73f5cb 100644
Binary files a/docs/images/unet-pipe.png and b/docs/images/unet-pipe.png differ
diff --git a/docs/images/workflows.png b/docs/images/workflows.png
index 858c566c763..44f6aa542e1 100644
Binary files a/docs/images/workflows.png and b/docs/images/workflows.png differ
diff --git a/docs/requirements.txt b/docs/requirements.txt
index d3a6ed95763..cb28412ad9e 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,6 +1,6 @@
--f https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
-torch>=1.6
-pytorch-ignite==0.4.9
+-f https://download.pytorch.org/whl/cpu/torch-1.12.1%2Bcpu-cp37-cp37m-linux_x86_64.whl
+torch>=1.8
+pytorch-ignite==0.4.10
numpy>=1.17
itk>=5.2
nibabel
@@ -20,7 +20,7 @@ sphinxcontrib-serializinghtml
sphinx-autodoc-typehints==1.11.1
pandas
einops
-transformers
+transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157
mlflow
tensorboardX
imagecodecs; platform_system == "Linux"
@@ -31,3 +31,6 @@ jsonschema
pynrrd
pydicom
h5py
+nni
+optuna
+opencv-python-headless
diff --git a/docs/source/api.rst b/docs/source/api.rst
index c2b19adeb2b..e16a19f4882 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -7,6 +7,8 @@ API Reference
:maxdepth: 1
apps
+ auto3dseg
+ fl
bundle
transforms
losses
diff --git a/docs/source/applications.md b/docs/source/applications.md
new file mode 100644
index 00000000000..5317a3d49a8
--- /dev/null
+++ b/docs/source/applications.md
@@ -0,0 +1,79 @@
+# Research and Application Highlights
+
+### COPLE-Net for COVID-19 Pneumonia Lesion Segmentation
+[A reimplementation](https://monai.io/research/coplenet-pneumonia-lesion-segmentation) of the COPLE-Net originally proposed by:
+
+G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang. (2020) "A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images." IEEE Transactions on Medical Imaging. 2020. [DOI: 10.1109/TMI.2020.3000314](https://doi.org/10.1109/TMI.2020.3000314)
+![coplenet](../images/coplenet.png)
+
+### LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation
+[A reimplementation](https://monai.io/research/lamp-automated-model-parallelism) of the LAMP system originally proposed by:
+
+Wentao Zhu, Can Zhao, Wenqi Li, Holger Roth, Ziyue Xu, and Daguang Xu (2020) "LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation." MICCAI 2020 (Early Accept, paper link: https://arxiv.org/abs/2006.12575)
+
+![LAMP UNet](../images/unet-pipe.png)
+
+### DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation
+MONAI integrated the `DiNTS` module to support more flexible topologies and joint two-level search. It provides a topology guaranteed discretization algorithm and a discretization aware topology loss for the search stage to minimize the discretization gap, and a cost usage aware search method which can search 3D networks with different GPU memory requirements. For more details, please check the [DiNTS tutorial](https://monai.io/research/dints.html).
+
+![DiNTS](../images/dints-overview.png)
+
+### Accounting for Dependencies in Deep Learning Based Multiple Instance Learning for Whole Slide Imaging
+For [classification of digital pathology whole slide images (WSI)](https://arxiv.org/abs/2111.01556), MONAI introduces new transforms and network modules for multiple instance learning. These include self-attention transformer blocks for explicitly accounting of the dependencies between instances (image patches) during training. For more details, please check out the [multiple instance learning tutorial](https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning). ![multi-instance](../images/mil-patches.jpg)
+
+### Self-supervised representation learning
+MONAI starts to explore self-supervised representation learning in this milestone release. The Vision Transformer has been extended to learn from self-supervised reconstruction tasks with various data augmentation and a regularized contrastive loss. The weights of the pre-trained backbone could be used to enhance the performance of the novel downstream deep learning tasks.
+
+The [tutorial](https://github.com/Project-MONAI/tutorials/tree/master/self_supervised_pretraining) shows how to generate a good set of pre-trained weights using unlabeled data with self-supervised tasks, then use the pre-trained weights to perform fine-tuning on a fully supervised volumetric segmentation task using a transformer based `UNETR`.
+
+![self-supervised](../images/ssl_overview.png)
+
+### Swin UNETR model for the task of multi-organ segmentation
+For [Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images](https://arxiv.org/abs/2201.01266), MONAI introduces new network modules for multi-organ segmentation task using the BTCV challenge dataset. The architecture of Swin UNETR:
+
+![swin-unetr](../images/swin_unetr.png)
+
+The [tutorial](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb) shows a typical pipeline of multi-organ segmentation based on Swin UNETR model, DiceCE loss function, Mean Dice, etc. And we used weights from self-supervised pre-training of Swin UNETR encoder (3D Swin Transformer) on a cohort of 5050 CT scans from publicly available datasets.
+
+### DeepGrow modules for interactive segmentation
+[A reimplementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/deepgrow) of the DeepGrow components, which is deep learning based semi-automated segmentation approach that aims to be a "smart" interactive tool for region of interest delineation in medical images, originally proposed by:
+
+Sakinis, Tomas, et al. "Interactive segmentation of medical images through fully convolutional neural networks." arXiv preprint arXiv:1903.08205 (2019).
+
+![deepgrow scheme](../images/deepgrow.png)
+
+### DeepEdit workflow for interactive segmentation
+DeepEdit is a method that combines an automatic and a semi-automatic approach for 3D medical images into a single deep learning-based model. The [implementation](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/deepedit) of the DeepEdit modules provides essential components for interactive segmentation. More details are available in the training and inference [tutorial](https://github.com/Project-MONAI/tutorials/tree/main/deepedit/ignite).
+
+The following figure shows the typical workflow of interactive segmentation:
+
+![deepedit workflow](../images/deepedit.png)
+
+### NuClick modules for interactive nuclei segmentation
+NuClick is a CNN-based approach to speed up collecting annotations for microscopic objects requiring minimum interaction from the annotator. The [implementation](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/nuclick) contains essential components for the training and inference workflows of NuClick interactive nuclei segmentation.
+
+The following figure is example outputs of NuClick (annotator click inside the nucleus and the mask will be generated by CNN):
+
+![nuclick output](../images/nuclick.png)
+
+### Lesion detection in digital pathology
+[Implementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/pathology) of the pathology detection components, which includes efficient whole slide imaging IO and several patch sampling methods with NVIDIA cuCIM library and SmartCache mechanism, FROC measurements for lesion and probabilistic post-processing for lesion detection.
+
+![digital pathology](../images/pathology.png)
+
+### Learning-based image registration
+Starting from v0.5.0, MONAI provides experimental features for building learning-based 2D/3D registration workflows. These include image similarity measures as loss functions, bending energy as model regularization, network architectures, warping modules. The components can be used to build the major unsupervised and weakly-supervised algorithms.
+
+The following figure shows the registration of CT images acquired at different time points for a single patient using MONAI:
+
+![3d registration](../images/3d_paired.png)
+
+### 2D and 3D detection workflow
+The [implementation](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/detection) contains 2D and 3D bounding box detection components of `RetinaNet`, which includes:bounding box operations, hard negative sampler, and RetinaNet detectors.
+
+The following figure shows the detection training and inference workflows:
+
+![detection workflow](../images/detection.png)
+
+### Reproducing the state-of-the-art Kaggle competition solutions
+[A reimplementation](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) of the 4th place solution of RANZCR CLiP - Catheter and Line Position Challenge in Kaggle: https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification
diff --git a/docs/source/apps.rst b/docs/source/apps.rst
index 3c5eed4002d..248813d6791 100644
--- a/docs/source/apps.rst
+++ b/docs/source/apps.rst
@@ -92,6 +92,10 @@ Applications
.. autoclass:: ProbMapProducer
:members:
+.. automodule:: monai.apps.pathology.losses.hovernet_loss
+.. autoclass:: HoVerNetLoss
+ :members:
+
.. automodule:: monai.apps.pathology.metrics
.. autoclass:: LesionFROC
:members:
@@ -126,6 +130,46 @@ Applications
.. autoclass:: TileOnGridd
:members:
+.. automodule:: monai.apps.pathology.transforms.post.array
+.. autoclass:: GenerateSuccinctContour
+ :members:
+.. autoclass:: GenerateInstanceContour
+ :members:
+.. autoclass:: GenerateInstanceCentroid
+ :members:
+.. autoclass:: GenerateInstanceType
+ :members:
+.. autoclass:: Watershed
+ :members:
+.. autoclass:: GenerateWatershedMask
+ :members:
+.. autoclass:: GenerateInstanceBorder
+ :members:
+.. autoclass:: GenerateDistanceMap
+ :members:
+.. autoclass:: GenerateWatershedMarkers
+ :members:
+
+.. automodule:: monai.apps.pathology.transforms.post.dictionary
+.. autoclass:: GenerateSuccinctContourd
+ :members:
+.. autoclass:: GenerateInstanceContourd
+ :members:
+.. autoclass:: GenerateInstanceCentroidd
+ :members:
+.. autoclass:: GenerateInstanceTyped
+ :members:
+.. autoclass:: Watershedd
+ :members:
+.. autoclass:: GenerateWatershedMaskd
+ :members:
+.. autoclass:: GenerateInstanceBorderd
+ :members:
+.. autoclass:: GenerateDistanceMapd
+ :members:
+.. autoclass:: GenerateWatershedMarkersd
+ :members:
+
`Detection`
-----------
@@ -191,6 +235,11 @@ Applications
`Reconstruction`
----------------
+FastMRIReader
+~~~~~~~~~~~~~
+.. autoclass:: monai.apps.reconstruction.fastmri_reader.FastMRIReader
+ :members:
+
`ConvertToTensorComplex`
~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: monai.apps.reconstruction.complex_utils.convert_to_tensor_complex
@@ -210,3 +259,10 @@ Applications
`ComplexConj`
~~~~~~~~~~~~~
.. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj
+
+`auto3dseg`
+-----------
+
+.. automodule:: monai.apps.auto3dseg
+ :members:
+ :imported-members:
diff --git a/docs/source/auto3dseg.rst b/docs/source/auto3dseg.rst
new file mode 100644
index 00000000000..4130880caec
--- /dev/null
+++ b/docs/source/auto3dseg.rst
@@ -0,0 +1,10 @@
+:github_url: https://github.com/Project-MONAI/MONAI
+
+.. _auto3dseg:
+
+Auto3dseg
+=========
+
+.. automodule:: monai.auto3dseg
+ :members:
+ :imported-members:
diff --git a/docs/source/bundle.rst b/docs/source/bundle.rst
index fc64d40f5ad..0977788ded0 100644
--- a/docs/source/bundle.rst
+++ b/docs/source/bundle.rst
@@ -38,6 +38,9 @@ Model Bundle
.. autofunction:: ckpt_export
.. autofunction:: download
.. autofunction:: load
+.. autofunction:: get_all_bundles_list
+.. autofunction:: get_bundle_info
+.. autofunction:: get_bundle_versions
.. autofunction:: run
.. autofunction:: verify_metadata
.. autofunction:: verify_net_in_out
diff --git a/docs/source/bundle_intro.rst b/docs/source/bundle_intro.rst
index 3d71de24b60..c43093db969 100644
--- a/docs/source/bundle_intro.rst
+++ b/docs/source/bundle_intro.rst
@@ -8,3 +8,7 @@ Bundle
mb_specification
config_syntax.md
+
+Detailed bundle examples and get started tutorial: https://github.com/Project-MONAI/tutorials/tree/main/bundle
+
+A collection of medical imaging models in the MONAI Bundle format: https://github.com/Project-MONAI/model-zoo
diff --git a/docs/source/conf.py b/docs/source/conf.py
index db0ca11be36..ecc1b3ff598 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -40,6 +40,7 @@
"engines",
"data",
"apps",
+ "fl",
"bundle",
"config",
"handlers",
@@ -48,6 +49,7 @@
"utils",
"inferers",
"optimizers",
+ "auto3dseg",
]
diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md
index 6cd4e62ef36..7cd71b507f2 100644
--- a/docs/source/config_syntax.md
+++ b/docs/source/config_syntax.md
@@ -15,7 +15,7 @@ Content:
- [`@` to reference Python objects in configurations](#to-reference-python-objects-in-configurations)
- [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions)
- [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements)
- - [`_target_` (`_disabled_` and `_requires_`) to instantiate a Python object](#instantiate-a-python-object)
+ - [`_target_` (`_disabled_`, `_desc_`, and `_requires_`) to instantiate a Python object](#instantiate-a-python-object)
- [The command line interface](#the-command-line-interface)
- [Recommendations](#recommendations)
@@ -67,7 +67,7 @@ or additionally, tune the input parameters then instantiate the component:
BasicUNet features: (32, 32, 32, 64, 64, 64).
```
-For more details on the `ConfigParser` API, please see https://docs.monai.io/en/latest/bundle.html#config-parser.
+For more details on the `ConfigParser` API, please see [`monai.bundle.ConfigParser`](https://docs.monai.io/en/latest/bundle.html#config-parser).
## Syntax examples explained
@@ -141,17 +141,19 @@ This dictionary will be instantiated as a Pytorch object at runtime.
{
"component_name": {
"_target_": "my.module.Class",
+ "_desc_": "this is a customized class which also triggers 'cudnn_opt' reference",
"_requires_": "@cudnn_opt",
"_disabled_": "true"}
}
```
-_Description:_ `_requires_` and `_disabled_` are optional keys.
-`_requires_` specifies references (string starts with `@`) or
-Python expression that will be evaluated/instantiated before `_target_` object is instantiated.
-It is useful when the component does not explicitly depend on the other ConfigItems via
-its arguments, but requires the dependencies to be instantiated/evaluated beforehand.
-`_disabled_` specifies a flag to indicate whether to skip the instantiation.
+_Description:_ `_requires_`, `_disabled_`, and `_desc_` are optional keys.
+- `_requires_` specifies references (string starts with `@`) or
+ Python expression that will be evaluated/instantiated before `_target_` object is instantiated.
+ It is useful when the component does not explicitly depend on the other ConfigItems via
+ its arguments, but requires the dependencies to be instantiated/evaluated beforehand.
+- `_disabled_` specifies a flag to indicate whether to skip the instantiation.
+- `_desc_` can be used for providing free text descriptions.
## The command line interface
diff --git a/docs/source/data.rst b/docs/source/data.rst
index c1d54b723da..8cb27cd347b 100644
--- a/docs/source/data.rst
+++ b/docs/source/data.rst
@@ -153,12 +153,6 @@ PILReader
:members:
-FastMRIReader
-~~~~~~~~~~~~~
-.. autoclass:: monai.apps.reconstruction.fastmri_reader.FastMRIReader
- :members:
-
-
Image writer
------------
@@ -318,6 +312,12 @@ OpenSlideWSIReader
.. autoclass:: monai.data.OpenSlideWSIReader
:members:
+TiffFileWSIReader
+~~~~~~~~~~~~~~~~~
+.. autoclass:: monai.data.TiffFileWSIReader
+ :members:
+
+
Whole slide image datasets
--------------------------
@@ -340,3 +340,18 @@ Bounding box
------------
.. automodule:: monai.data.box_utils
:members:
+
+Video datasets
+--------------
+
+VideoDataset
+~~~~~~~~~~~~
+.. autoclass:: monai.data.video_dataset.VideoDataset
+
+VideoFileDataset
+~~~~~~~~~~~~~~~~
+.. autoclass:: monai.data.video_dataset.VideoFileDataset
+
+CameraDataset
+~~~~~~~~~~~~~
+.. autoclass:: monai.data.video_dataset.CameraDataset
diff --git a/docs/source/engines.rst b/docs/source/engines.rst
index 0cd40afb786..6d2694124b1 100644
--- a/docs/source/engines.rst
+++ b/docs/source/engines.rst
@@ -16,11 +16,6 @@ Workflows
.. currentmodule:: monai.engines
-`BaseWorkflow`
-~~~~~~~~~~~~~~
-.. autoclass:: BaseWorkflow
- :members:
-
`Workflow`
~~~~~~~~~~
.. autoclass:: Workflow
diff --git a/docs/source/fl.rst b/docs/source/fl.rst
new file mode 100644
index 00000000000..412063013ed
--- /dev/null
+++ b/docs/source/fl.rst
@@ -0,0 +1,28 @@
+:github_url: https://github.com/Project-MONAI/MONAI
+
+.. _fl:
+
+Federated Learning
+==================
+.. currentmodule:: monai.fl.client
+
+`Client Base Classes`
+---------------------
+
+.. autoclass:: BaseClient
+ :members:
+
+.. autoclass:: ClientAlgo
+ :members:
+
+.. autoclass:: ClientAlgoStats
+ :members:
+
+`MONAI Bundle Reference Implementations`
+----------------------------------------
+
+.. autoclass:: MonaiAlgo
+ :members:
+
+.. autoclass:: MonaiAlgoStats
+ :members:
diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index 0172529f40c..5b408cfa71e 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -41,6 +41,12 @@ Mean Dice metrics handler
:members:
+Mean IoU metric handler
+-----------------------
+.. autoclass:: MeanIoUHandler
+ :members:
+
+
ROC AUC metrics handler
-----------------------
.. autoclass:: ROCAUC
@@ -65,6 +71,12 @@ Surface distance metrics handler
:members:
+Panoptic Quality metrics handler
+--------------------------------
+.. autoclass:: PanopticQuality
+ :members:
+
+
Mean squared error metrics handler
----------------------------------
.. autoclass:: MeanSquaredError
@@ -95,6 +107,12 @@ Metric logger
:members:
+Logfile handler
+---------------
+.. autoclass:: LogfileHandler
+ :members:
+
+
Training stats handler
----------------------
.. autoclass:: StatsHandler
diff --git a/docs/source/highlights.md b/docs/source/highlights.md
deleted file mode 100644
index 4e00471d4a5..00000000000
--- a/docs/source/highlights.md
+++ /dev/null
@@ -1,573 +0,0 @@
-# Modules Overview
-
-MONAI aims at supporting deep learning in medical image analysis at multiple granularities.
-This figure shows a typical example of the end-to-end workflow:
-![an end to end workflow](../images/end_to_end.png)
-
-## MONAI architecture
-The design principle of MONAI is to provide flexible and light APIs for users with varying expertise.
-1. All the core components are independent modules, which can be easily integrated into any existing PyTorch program.
-2. Users can leverage the workflows in MONAI to quickly set up a robust training or evaluation program for research experiments.
-3. Rich examples and demos are provided to demonstrate the key features.
-4. Researchers contribute implementations based on the state-of-the-art for the latest research challenges, including COVID-19 image analysis, Model Parallel, etc.
-
-The overall architecture and modules are shown in the following figure:
-![architecture overview](../images/arch_modules_v0.4.png)
-The rest of this page provides more details for each module.
-
-* [Data I/O, processing and augmentation](#medical-image-data-i-o-processing-and-augmentation)
-* [Datasets and DataLoader](#datasets-and-dataloader)
-* [Loss functions](#losses)
-* [Optimizers](#optimizers)
-* [Network architectures](#network-architectures)
-* [Evaluation](#evaluation)
-* [Visualization](#visualization)
-* [Result writing](#result-writing)
-* [Workflows](#workflows)
-* [Bundle](#bundle)
-* [Research](#research)
-* [Performance optimization and GPU acceleration](#performance-optimization-and-gpu-acceleration)
-* [Applications](#applications)
-
-## Medical image data I/O, processing and augmentation
-Medical images require highly specialized methods for I/O, preprocessing, and augmentation. Medical images are often in specialized formats with rich meta-information, and the data volumes are often high-dimensional. These require carefully designed manipulation procedures. The medical imaging focus of MONAI is enabled by powerful and flexible image transformations that facilitate user-friendly, reproducible, optimized medical data pre-processing pipelines.
-
-### 1. Transforms support both Dictionary and Array format data
-- The widely used computer vision packages (such as ``torchvision``) focus on spatially 2D array image processing. MONAI provides more domain-specific transformations for both spatially 2D and 3D and retains the flexible transformation "compose" feature.
-- As medical image preprocessing often requires additional fine-grained system parameters, MONAI provides transforms for input data encapsulated in python dictionaries. Users can specify the keys corresponding to the expected data fields and system parameters to compose complex transformations.
-
-There is a rich set of transforms in six categories: Crop & Pad, Intensity, IO, Post-processing, Spatial, and Utilities. For more details, please visit [all the transforms in MONAI](https://docs.monai.io/en/latest/transforms.html).
-
-Almost all the transforms expect the input data to have a channel-first shape format: `[Channel dim, spatial dim 1, spatial dim 2, ...]`.
-Flexible [base APIs](https://github.com/Project-MONAI/MONAI/tree/dev/monai/transforms) are also provided. The `monai.transforms` module is
-easily extensible.
-
-### 2. Medical specific transforms
-MONAI aims at providing a comprehensive medical image specific
-transformations. These currently include, for example:
-- `LoadImage`: Load medical specific formats file from provided path
-- `Spacing`: Resample input image into the specified `pixdim`
-- `Orientation`: Change the image's orientation into the specified `axcodes`
-- `RandGaussianNoise`: Perturb image intensities by adding statistical noises
-- `NormalizeIntensity`: Intensity Normalization based on mean and standard deviation
-- `Affine`: Transform image based on the affine parameters
-- `Rand2DElastic`: Random elastic deformation and affine in 2D
-- `Rand3DElastic`: Random elastic deformation and affine in 3D
-
-[2D transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/transforms_demo_2d.ipynb) shows the detailed usage of several MONAI medical image specific transforms.
-![2d transform examples](../images/medical_transforms.png)
-
-
-### 3. Transforms support both NumPy array and PyTorch Tensor (CPU or GPU accelerated)
-From MONAI v0.7 we introduced PyTorch `Tensor` based computation in transforms, many transforms already support both `NumPy array` and `Tensor` as input types and computational backends. To get the supported backends of every transform, please execute: `python monai/transforms/utils.py`.
-
-To accelerate the transforms, a common approach is to leverage GPU parallel-computation. Users can first convert input data into GPU Tensor by `ToTensor` or `EnsureType` transform, then the following transforms can execute on GPU based on PyTorch `Tensor` APIs.
-GPU transform tutorial is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb).
-
-### 4. Fused spatial transforms
-As medical image volumes are usually large (in multi-dimensional arrays), pre-processing performance affects the overall pipeline speed. MONAI provides affine transforms to execute fused spatial operations.
-
-For example:
-```py
-# create an Affine transform
-affine = Affine(
- rotate_params=np.pi/4,
- scale_params=(1.2, 1.2),
- translate_params=(200, 40),
- padding_mode='zeros',
-)
-# convert the image using bilinear interpolation
-new_img = affine(image, spatial_size=(300, 400), mode='bilinear')
-```
-Experiments and test results are available at [Fused transforms test](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/transform_speed.ipynb).
-
-Currently, all the geometric image transforms (Spacing, Zoom, Rotate, Resize, etc.) are designed based on the PyTorch native interfaces. So all of them support GPU acceleration via `GPU Tensor` operations for high performance.
-
-[Geometric transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/3d_image_transforms.ipynb) indicates the usage of affine transforms with 3D medical images.
-![3d transform examples](../images/affine.png)
-
-### 5. Randomly crop out batch images based on positive/negative ratio
-Medical image data volume may be too large to fit into GPU memory. A widely-used approach is to randomly draw small size data samples during training and run a “sliding window” routine for inference. MONAI currently provides general random sampling strategies including class-balanced fixed ratio sampling which may help stabilize the patch-based training process. A typical example is in [Spleen 3D segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb), which achieves the class-balanced sampling with `RandCropByPosNegLabel` transform.
-
-### 6. Deterministic training for reproducibility
-Deterministic training support is necessary and important for deep learning research, especially in the medical field. Users can easily set the random seed to all the random transforms in MONAI locally and will not affect other non-deterministic modules in the user's program.
-
-For example:
-```py
-# define a transform chain for pre-processing
-train_transforms = monai.transforms.Compose([
- LoadImaged(keys=['image', 'label']),
- RandRotate90d(keys=['image', 'label'], prob=0.2, spatial_axes=[0, 2]),
- ... ...
-])
-# set determinism for reproducibility
-train_transforms.set_random_state(seed=0)
-```
-Users can also enable/disable deterministic at the beginning of training program:
-```py
-monai.utils.set_determinism(seed=0, additional_settings=None)
-```
-
-### 7. Multiple transform chains
-To apply different transforms on the same data and concatenate the results, MONAI provides `CopyItems` transform to make copies of specified items in the data dictionary and `ConcatItems` transform to combine specified items on the expected dimension, and also provides `DeleteItems` transform to delete unnecessary items to save memory.
-
-Typical usage is to scale the intensity of the same image into different ranges and concatenate the results together.
-![multiple transform chains](../images/multi_transform_chains.png)
-
-### 8. Debug transforms with DataStats
-When transforms are combined with the "compose" function, it's not easy to track the output of a specific transform. To help debug errors in the composed transforms, MONAI provides utility transforms such as `DataStats` to print out intermediate data properties such as `data shape`, `value range`, `data value`, `Additional information`, etc. It's a self-contained transform and can be integrated into any transform chain.
-
-### 9. Post-processing transforms for model output
-MONAI also provides post-processing transforms for handling the model outputs. Currently, the transforms include:
-- Adding an activation layer (Sigmoid, Softmax, etc.).
-- Converting to discrete values (Argmax, One-Hot, Threshold value, etc), as below figure (b).
-- Splitting multi-channel data into multiple single channels.
-- Removing segmentation noise based on Connected Component Analysis, as below figure (c).
-- Extracting contour of segmentation result, which can be used to map to original image and evaluate the model, as below figure (d) and (e).
-
-After decollating the batch data of model output and applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Postprocessing transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/postprocessing_transforms.ipynb) shows an example with several main transforms for post-processing.
-![post-processing transforms](../images/postprocessing_transforms.png)
-
-### 10. Integrate third-party transforms
-The design of MONAI transforms emphasis code readability and usability. It works for array data or dictionary-based data. MONAI also provides `Adaptor` tools to accommodate different data format for 3rd party transforms. To convert the data shapes or types, utility transforms such as `ToTensor`, `ToNumpy`, `SqueezeDim` are also provided. So it's easy to enhance the transform chain by seamlessly integrating transforms from external packages, including: `ITK`, `BatchGenerator`, `TorchIO` and `Rising`.
-
-For more details, please check out the tutorial: [integrate 3rd party transforms into MONAI program](https://github.com/Project-MONAI/tutorials/blob/master/modules/integrate_3rd_party_transforms.ipynb).
-
-In digital pathology training, due to the immense burden of loading images, the CPU is preoccupied by loading images and cannot catch up with preparing the data. This causes the pipeline to become IO bound and results in under-utilization of GPU. To overcome this bottleneck, [cuCIM](https://github.com/rapidsai/cucim) has implemented an optimized version of several common transforms that we are using in digital pathology pipeline. These transforms are natively being run on GPU and act on CuPy arrays. MONAI provides `CuCIM` and `RandCuCIM` adapters to integrate the `cuCIM` library. For instance:
-```py
-RandCuCIM(name="color_jitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04)
-CuCIM(name="scale_intensity_range", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0)
-```
-It has shown a significant speed up in pathology training metastasis detection model.
-
-### 11. IO factory for medical image formats
-Many popular image formats exist in the medical domain, and they are quite different with rich metadata information. To easily handle different medical image formats in the same pipeline, [MONAI provides `LoadImage` transform](https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb), which can automatically choose image readers based on the supported suffixes and in the following priority order:
-- User-specified reader at runtime when calling this loader.
-- Registered readers from the latest to the first in the list.
-- Default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), (npz, npy -> NumpyReader), (others -> ITKReader).
-
-The `ImageReader` API is quite straightforward, users can easily extend it for their customized image readers.
-
-With these pre-defined image readers, MONAI can load images in formats: `NIfTI`, `DICOM`, `PNG`, `JPG`, `BMP`, `NPY/NPZ`, etc.
-
-### 12. Save transform data into NIfTI or PNG files
-To convert images into files or debug the transform chain, MONAI provides `SaveImage` transform. Users can inject this transform into the transform chain to save the results.
-
-### 13. Automatically ensure `channel-first` data shape
-Medical images have different shape formats. They can be `channel-last`, `channel-first` or even `no-channel`. We may, for example, want to load several `no-channel` images and stack them as `channel-first` data. To improve the user experience, MONAI provided an `EnsureChannelFirst` transform to automatically detect data shape according to the meta information and convert it to the `channel-first` format consistently.
-
-### 14. Invert spatial transforms and test-time augmentations
-It is often desirable to invert the previously applied spatial transforms (resize, flip, rotate, zoom, crop, pad, etc.) within the deep learning workflows, for example, to resume to the original imaging space after processing the image data in a normalized data space. Many spatial transforms are enhanced with an `inverse` operation since in v0.5. The [model inference tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py) shows a basic example.
-
-If the pipeline includes random transformations, users may want to observe the effect that these transformations have on the output. The typical approach is that we pass the same input through the transforms multiple times with different random realizations. Then use the inverse transforms to move all the results to a common space, and calculate the metrics. MONAI provided `TestTimeAugmentation` for this feature, which by default will calculate the `mode`, `mean`, `standard deviation` and `volume variation coefficient`.
-
-[Invert transforms and TTA tutorials](https://github.com/Project-MONAI/tutorials/blob/master/modules/inverse_transforms_and_test_time_augmentations.ipynb) introduce details about the API with usage examples.
-
-(1) The last column is the inverted data of model output:
-
-![invert transform](../images/invert_transforms.png)
-
-(2) The TTA results of `mode`, `mean` and `standard deviation`:
-
-![test time augmentation](../images/tta.png)
-
-### 15. Visualization of transform examples
-To help clearly introduce the transform functionalities, MONAI provides visualization examples in the [API document](https://docs.monai.io/en/latest/transforms.html) for almost every transform, including spatial transforms, intensity transforms, crop / pad transforms, etc.
-
-For example:
-
-![rand gaussian noise](../images/rand_gaussian_noise.png)
-
-## Datasets and DataLoader
-### 1. Cache IO and transforms data to accelerate training
-Users often need to train the model with many (potentially thousands of) epochs over the data to achieve the desired model quality. A native PyTorch implementation may repeatedly load data and run the same preprocessing steps for every epoch during training, which can be time-consuming and unnecessary, especially when the medical image volumes are large.
-
-MONAI provides a multi-thread `CacheDataset` and `LMDBDataset` to accelerate these transformation steps during training by storing the intermediate outcomes before the first randomized transform in the transform chain. Enabling this feature could potentially give 10x training speedups in the [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb).
-
-![digital pathology](../images/cache_dataset.png)
-
-### 2. Cache intermediate outcomes into persistent storage
-The `PersistentDataset` is similar to the CacheDataset, where the intermediate cache values are persisted to disk storage or LMDB for rapid retrieval between experimental runs (as is the case when tuning hyperparameters), or when the entire data set size exceeds available memory. The `PersistentDataset` could achieve similar performance when comparing to `CacheDataset` in [Datasets experiment](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb).
-
-![cachedataset speed](../images/datasets_speed.png)
-
-### 3. SmartCache mechanism for big datasets
-During training with large volume dataset, an efficient approach is to only train with a subset of the dataset in an epoch and dynamically replace part of the subset in every epoch. It's the `SmartCache` mechanism in [NVIDIA Clara-train SDK](https://docs.nvidia.com/clara/tlt-mi/clara-train-sdk-v3.0/nvmidl/additional_features/smart_cache.html#smart-cache).
-
-MONAI provides a PyTorch version `SmartCache` as `SmartCacheDataset`. In each epoch, only the items in the cache are used for training, at the same time, another thread is preparing replacement items by applying the transform sequence to items not in the cache. Once one epoch is completed, `SmartCache` replaces the same number of items with replacement items.
-
-For example, if we have 5 images: `[image1, image2, image3, image4, image5]`, and `cache_num=4`, `replace_rate=0.25`. So the actual training images cached and replaced for every epoch are as below:
-```
-epoch 1: [image1, image2, image3, image4]
-epoch 2: [image2, image3, image4, image5]
-epoch 3: [image3, image4, image5, image1]
-epoch 3: [image4, image5, image1, image2]
-epoch N: [image[N % 5] ...]
-```
-Full example of `SmartCacheDataset` is available at [Distributed training with SmartCache](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/distributed_training/unet_training_smartcache.py).
-
-### 4. Zip multiple PyTorch datasets and fuse the output
-MONAI provides `ZipDataset` to associate multiple PyTorch datasets and combine the output data (with the same corresponding batch index) into a tuple, which can be helpful to execute complex training processes based on various data sources.
-
-For example:
-```py
-class DatasetA(Dataset):
- def __getitem__(self, index: int):
- return image_data[index]
-
-class DatasetB(Dataset):
- def __getitem__(self, index: int):
- return extra_data[index]
-
-dataset = ZipDataset([DatasetA(), DatasetB()], transform)
-```
-
-### 5. PatchDataset
-`monai.data.PatchDataset` provides a flexible API to combine both image- and patch-level preprocessing:
-```python
-image_dataset = Dataset(input_images, transforms=image_transforms)
-patch_dataset = PatchDataset(
- dataset=image_dataset, patch_func=sampler,
- samples_per_image=n_samples, transform=patch_transforms)
-```
-It supports user-specified `image_transforms` and `patch_transforms` with customisable patch sampling strategies,
-which decouples the two-level computations in a multiprocess context.
-
-### 6. Predefined Datasets for public medical data
-To quickly get started with popular training data in the medical domain, MONAI provides several data-specific Datasets(like: `MedNISTDataset`, `DecathlonDataset`, etc.), which include downloading from our AWS storage, extracting data files and support generation of training/evaluation items with transforms. And they are flexible in that users can easily modify the JSON config file to change the default behaviors.
-
-MONAI always welcome new contributions of public datasets, please refer to existing Datasets and leverage the download and extracting APIs, etc. [Public datasets tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/public_datasets.ipynb) indicates how to quickly set up training workflows with `MedNISTDataset` and `DecathlonDataset` and how to create a new `Dataset` for public data.
-
-The common workflow of predefined datasets:
-
-![pre-defined dataset](../images/dataset_progress.png)
-
-### 7. Partition dataset for cross validation
-The `partition_dataset` utility in MONAI can perform different types of partitioning for training and validation or cross-validation. It supports shuffling based on a specified random seed, and will return a set of datasets, each dataset contains one partition. And it can split the dataset based on specified ratios or evenly split into `num_partitions`. For given class labels, it can also make sure the same ratio of classes in every partition.
-
-### 8. CSV `Dataset` and `IterableDataset`
-CSV tables are often used in additional to image data to incorporate adjunct information, such as patient demographics, lab results, image acquisition parameters and other non-image data, MONAI provides `CSVDataset` to load CSV files and `CSVIterableDataset` to load large CSV files with scalable data access.
-In addition to the regular preprocessing transform while loading, it also supports multiple CSV files loading, joining tables, rows and columns selection and grouping. [CSVDatasets tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/csv_datasets.ipynb) shows detailed usage examples.
-
-### 9. `ThreadDataLoader` vs. `DataLoader`
-If the transforms are light-weighted, especially when we cache all the data in RAM, the multiprocessing of PyTorch `DataLoader` may cause unnecessary IPC time and cause the drop of GPU utilization after every epoch. MONAI provides `ThreadDataLoader` which executes the transforms in a separate thread:
-
-![threaddataloader](../images/threaddataloader.png)
-
-a `ThreadDataLoader` example is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb).
-
-## Losses
-There are domain-specific loss functions in the medical imaging research which are not typically used in generic computer vision tasks. As an important module of MONAI, these loss functions are implemented in PyTorch, such as `DiceLoss`, `GeneralizedDiceLoss`, `MaskedDiceLoss`, `TverskyLoss`, `FocalLoss`, `DiceCELoss`, and `DiceFocalLoss`, etc.
-
-## Optimizers
-MONAI provides several advanced features in optimizers to help accelerate the training or fine-tuning progress. For example, `Novograd` optimizer can be used to converge faster than the traditional optimizers. And users can easily define different learning rates for the model layers based [on the `generate_param_groups` utility API](https://github.com/Project-MONAI/tutorials/blob/master/modules/layer_wise_learning_rate.ipynb).
-
-Another important feature is `LearningRateFinder`. The learning rate range test increases the learning rate in a pre-training run between two boundaries in a linear or exponential manner. It provides valuable information on how well the network can be trained over a range of learning rates and what the optimal learning rates are. [LearningRateFinder tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/learning_rate.ipynb) indicates the API usage examples.
-
-![learning rate finder plot](../images/lr_finder.png)
-
-## Network architectures
-Some deep neural network architectures have shown to be particularly effective for medical imaging analysis tasks. MONAI implements reference networks with the aims of both flexibility and code readability.
-
-### 1. Predefined layers and blocks
-To leverage the common network layers and blocks, MONAI provides several predefined layers and blocks which are compatible with 1D, 2D and 3D networks. Users can easily integrate the layer factories in their customised networks.
-
-For example:
-```py
-# import MONAI’s layer factory
-from monai.networks.layers import Conv
-
-# adds a transposed convolution layer to the network
-# which is compatible with different spatial dimensions.
-name, dimension = Conv.CONVTRANS, 3
-conv_type = Conv[name, dimension]
-add_module('conv1', conv_type(in_channels, out_channels, kernel_size=1, bias=False))
-```
-
-### 2. Implementation of generic 2D/3D networks
-And there are several 1D/2D/3D-compatible implementations of intermediate blocks and generic networks, such as UNet, DynUNet, DenseNet, GAN, AHNet, VNet, SENet(and SEResNet, SEResNeXt), SegResNet, EfficientNet, Attention-based transformer networks, Multi-instance learning networks, DiNTS for AutoML, etc. All the networks can support PyTorch serialization pipeline based on `torch.jit.script`.
-
-### 3. Network adapter to finetune final layers
-Instead of training from scratch, we often leverage the existing models, and finetune the final layers of a network for new learning tasks. MONAI provides a `NetAdapter` to easily replace the last layer of a model by a convolutional layer or a fully-connected layer. A typical usage example is to adapt [Torchvision models trained with ImageNet](https://pytorch.org/vision/stable/models.html) for other learning tasks.
-
-## Evaluation
-To run model inferences and evaluate the model quality, MONAI provides reference implementations for the relevant widely-used approaches. Currently, several popular evaluation metrics and inference patterns are included:
-
-### 1. Sliding window inference
-For model inferences on large volumes, the sliding window approach is a popular choice to achieve high performance while having flexible memory requirements (_alternatively, please check out the latest research on [model parallel training](#lamp-large-deep-nets-with-automated-model-parallelism-for-image-segmentation) using MONAI_). It also supports `overlap` and `blending_mode` configurations to handle the overlapped windows for better performances.
-
-A typical process is:
-1. Select continuous windows on the original image.
-2. Iteratively run batched window inferences until all windows are analyzed.
-3. Aggregate the inference outputs to a single segmentation map.
-4. Save the results to file or compute some evaluation metrics.
-![sliding window scheme](../images/sliding_window.png)
-
-The [Spleen 3D segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb) leverages `SlidingWindow` inference for validation.
-
-### 2. Metrics for medical tasks
-Various useful evaluation metrics have been used to measure the quality of medical image specific models. MONAI already implemented many medical domain-specific metrics, such as: `Mean Dice`, `ROCAUC`, `Confusion Matrices`, `Hausdorff Distance`, `Surface Distance`, `Occlusion Sensitivity`.
-
-For example, `Mean Dice` score can be used for segmentation tasks, and the area under the ROC curve(`ROCAUC`) for classification tasks. We continue to integrate more options.
-
-1. MONAI provides flexible base APIs for metrics
-The base classes of MONAI metrics implement the basic computation logic for both iteration and epoch-based metrics. They are a good starting point for customized metrics.
-2. All the metrics support data parallel computation
-With a `Cumulative` base class, intermediate metric outcomes can be automatically buffered, cumulated, synced across distributed processes, and aggregated for the final results. [Multi-processing computation example](https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py) shows how to compute metrics based on saved predictions and labels in multi-processing environment.
-3. All the metrics modules can handle `batch-first` Tensors and list of `channel-first` Tensors
-
-### 3. Metrics report generation
-During evaluation, users usually save the metrics of every input image, then analyze the bad cases to improve the deep learning pipeline. To save detailed information of metrics, MONAI provided a handler `MetricsSaver`, which can save the final metric values, raw metric of every model output channel of every input image, metrics summary report of operations: `mean`, `median`, `max`, `min`, `percentile`, `std`, etc. The `MeanDice` reports of validation with prostate dataset are as below:
-
-![metrics report example](../images/metrics_report.png)
-
-## Visualization
-Beyond the simple point and curve plotting, MONAI provides intuitive interfaces to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). To work with ignite program, MONAI also provides several ignite handlers to visualize training curve and metrics with `TensorBoard` or `MLFlow`, more details is available in [TensorBoard and MLFlow handlers example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb).
-
-To easily visualize a 3D image as frames of 2D images, MONAI provides the utility `matshow3d` based on `matplotlib` library. It can plot frames of image for the specified dimension, showing a spleen 3D image as example:
-`matshow3d(volume=image, figsize=(100, 100), every_n=10, frame_dim=-1 show=True, cmap="gray")`
-
-![matshow3d example](../images/matshow3d.png)
-
-MONAI also provides the `blend_images` utility to blend the `image` and `label` to an RGB color image to better visualize the segmentation regions with the specified `cmap` mode and weights, etc. Showing a spleen segmentation `image` and the corresponding `label` as example:
-
-![blend example](../images/blend.png)
-
-For more details of `TensorBoard utility`, `matshow3d` and `blend_images`, please check the [visualization tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/transform_visualization.ipynb).
-
-And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models:
-
-![CAM visualization example](../images/cam.png)
-
-The above example is generated by computing [GradCAM/GradCAM++ from a lung CT lesion classification model](https://github.com/Project-MONAI/tutorials/tree/master/modules/interpretability).
-
-## Result writing
-Currently, MONAI supports writing the model outputs as NIfTI files or PNG files for segmentation tasks, and as CSV files for classification tasks. And the writers can restore the data spacing, orientation or shape according to the `original_shape` or `original_affine` information from the input image.
-
-MONAI provides `SaveImage` transform to write the data volumes of image or prediction with automatically chose image writers based on the specified suffix, writers like: `NibabelWriter`, `ITKWriter`, `PILWriter`.
-The `ImageWriter` API can be easily extended it for customized image writing modules.
-
-MONAI also supports to save the relevant statistics and evaluation metrics details automatically computed from the evaluation process.
-
-## Workflows
-To quickly set up training and evaluation experiments, MONAI provides a set of workflows to significantly simplify the modules and allow for fast prototyping.
-
-These features decouple the domain-specific components and the generic machine learning processes. They also provide a set of unify APIs for higher level applications (such as AutoML, Federated Learning).
-The trainers and evaluators of the workflows are compatible with pytorch-ignite `Engine` and `Event-Handler` mechanism. There are rich event handlers in MONAI to independently attach to the trainer or evaluator, and users can register additional `custom events` to workflows.
-
-### 1. General workflows pipeline
-The workflow and some of MONAI event handlers are shown as below:
-![workflow pipeline](../images/workflows.png)
-
-The end-to-end training and evaluation examples are available at [Workflow examples](https://github.com/Project-MONAI/tutorials/tree/master/modules/engines).
-
-### 2. EnsembleEvaluator
-Models ensemble is a popular strategy in machine learning and deep learning areas to achieve more accurate and more stable outputs. A typical practice is:
-1. Split all the training dataset into K folds.
-2. Train K models with every K-1 folds data.
-3. Execute inference on the test data with all the K models.
-4. Compute the average values with weights or vote the most common value as the final result.
-
-![model ensemble](../images/models_ensemble.png)
-More details of practice is at [Cross validation and model ensemble tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/cross_validation_models_ensemble.ipynb).
-
-
-### 3. Transfer learning for different input / output classes
-`Transfer-learning` is a common and efficient training approach, especially in the medical-specific domain where obtaining large datasets for training can be difficult. So transfer learning from a pre-trained checkpoint can significantly improve the model metrics and shorten training time.
-
-MONAI provided `CheckpointLoader` to load a checkpoint for the workflow before training, and it allows some `layer names` of the current network don't match the checkpoint, or some `layer shapes` don't match the checkpoint, which can be useful if the current task has different input image classes or output classes.
-
-### 4. Transfer learning based on NVIDIA Clara MMAR
-[The MMAR (Medical Model ARchive)](https://docs.nvidia.com/clara/clara-train-sdk/pt/mmar.html) defines a data structure for organizing all artifacts produced during the model development life cycle. NVIDIA Clara provides rich existing MMARs of medical domain-specific models. And these MMARs include all the information about the model including configurations and scripts to provide a work space to perform all model development tasks. To better leverage the pretrained MMARs released on Nvidia GPU cloud, MONAI provides pythonic APIs to access the MMARs.
-
-The following figure compares the loss curves and validation scores for (1) training from scratch (the green line), (2) applying a pretrained model without training (the magenta line), (3) training from the pretrained model (the blue line), according to the number of training epochs
-(the tutorial is available at [transfer_mmar](https://github.com/Project-MONAI/tutorials/blob/master/modules/transfer_mmar.ipynb)):
-
-![transfer_mmar](../images/transfer_mmar.png)
-
-### 5. Decollate batch data for flexible postprocessings
-`decollate batch` is introduced in MONAI v0.6, which simplifies the post-processing transforms and provides flexible following operations on a batch of data with various data shapes. It can decollate batched data (e.g. model predictions) into a list of tensors, for the benefits such as:
-1. enabling postprocessing transforms for each item independently -- randomised transforms could be applied differently for each predicted item in a batch.
-2. simplifying the transform APIs and reducing the input validation burdens because both the preprocessing and postprocessing transforms now only need to support the "channel-first" input format.
-3. enabling the `Invertd` transform for the predictions and the inverted data with different shapes, as the data items are in a list, not stacked in a single tensor.
-4. allowing for both batch-first tensor and list of channel-first tensors in a flexible metric computation.
-
-A typical process of `decollate batch` is illustrated as follows (with a `batch_size=N` model predictions and labels as an example):
-![decollate_batch](../images/decollate_batch.png)
-
-[decollate batch tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/decollate_batch.ipynb) shows a detailed usage example based on a PyTorch native workflow.
-
-### 6. Easy to integrate into popular workflows
-Except for the pytorch-ignite based `monai.engines`, most of the MONAI modules could be used independently or combined with other software packages. For example, MONAI can be easily integrated into popular frameworks such as PyTorch-Lightning and Catalyst: [Lightning segmentation](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d_lightning.ipynb) and [Lightning + TorchIO](https://github.com/Project-MONAI/tutorials/blob/master/modules/TorchIO_MONAI_PyTorch_Lightning.ipynb) tutorials show the PyTorch Lightning programs with MONAI modules, and [Catalyst segmentation](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_catalyst.ipynb) shows the Catalyst program with MONAI modules.
-
-## Bundle
-The objective of a MONAI bundle is to define a packaged model which includes the critical information necessary to allow users and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a single network as a pickled state dictionary plus optionally a Torchscript object and/or an ONNX object. Additional JSON files are included to store metadata about the model, information for constructing training, inference, and post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes to include. More details are available at [bundle specification](https://docs.monai.io/en/latest/mb_specification.html).
-
-The key benefits of bundle are to define the model package and support building Python-based workflows via structured configurations:
-- Self-contained model package include all the necessary information.
-- Structured config can be used to easily reconstruct or prototype deep learning workflows.
-- Config files can provide good readability and usability by separating parameter settings from the Python code.
-- Config files can describe flexible workflow and components, allows for different low-level Python implementations
-- Learning paradigms at a higher level such as federated learning and AutoML can be decoupled from the component details.
-
-A typical bundle example can include:
-```
- ModelName
- ┣━ configs
- ┃ ┗━ metadata.json
- ┣━ models
- ┃ ┣━ model.pt
- ┃ ┣━ *model.ts
- ┃ ┗━ *model.onnx
- ┗━ docs
- ┣━ *README.md
- ┗━ *license.txt
-```
-Details about the bundle config definition and syntax & examples are at [config syntax](https://docs.monai.io/en/latest/config_syntax.html).
-A step-by-step [get started](https://github.com/Project-MONAI/tutorials/blob/master/modules/bundles/get_started.ipynb) tutorial notebook can help users quickly set up a bundle.
-
-And [bundle examples](https://github.com/Project-MONAI/tutorials/tree/main/modules/bundle) provides more real-world examples and advanced features of bundle, including a bundle example for 3D segmentation of the spleen from CT image, use cases of bringing customized python components, parsing the config files in your own python program, etc.
-
-## Research
-There are several research prototypes in MONAI corresponding to the recently published papers that address advanced research problems.
-We always welcome contributions in forms of comments, suggestions, and code implementations.
-
-The generic patterns/modules identified from the research prototypes will be integrated into MONAI core functionality.
-
-### 1. COPLE-Net for COVID-19 Pneumonia Lesion Segmentation
-[A reimplementation](https://monai.io/research/coplenet-pneumonia-lesion-segmentation) of the COPLE-Net originally proposed by:
-
-G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang. (2020) "A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images." IEEE Transactions on Medical Imaging. 2020. [DOI: 10.1109/TMI.2020.3000314](https://doi.org/10.1109/TMI.2020.3000314)
-![coplenet](../images/coplenet.png)
-
-### 2. LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation
-[A reimplementation](https://monai.io/research/lamp-automated-model-parallelism) of the LAMP system originally proposed by:
-
-Wentao Zhu, Can Zhao, Wenqi Li, Holger Roth, Ziyue Xu, and Daguang Xu (2020) "LAMP: Large Deep Nets with Automated Model Parallelism for Image Segmentation." MICCAI 2020 (Early Accept, paper link: https://arxiv.org/abs/2006.12575)
-
-![LAMP UNet](../images/unet-pipe.png)
-
-### 3. DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation
-MONAI integrated the `DiNTS` module to support more flexible topologies and joint two-level search. It provides a topology guaranteed discretization algorithm and a discretization aware topology loss for the search stage to minimize the discretization gap, and a cost usage aware search method which can search 3D networks with different GPU memory requirements. For more details, please check the [DiNTS tutorial](https://monai.io/research/dints.html).
-
-![DiNTS](../images/dints-overview.png)
-
-### 4. Accounting for Dependencies in Deep Learning Based Multiple Instance Learning for Whole Slide Imaging
-For [classification of digital pathology whole slide images (WSI)](https://arxiv.org/abs/2111.01556), MONAI introduces new transforms and network modules for multiple instance learning. These include self-attention transformer blocks for explicitly accounting of the dependencies between instances (image patches) during training. For more details, please check out the [multiple instance learning tutorial](https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning). ![multi-instance](../images/mil-patches.jpg)
-
-### 5. Self-supervised representation learning
-MONAI starts to explore self-supervised representation learning in this milestone release. The Vision Transformer has been extended to learn from self-supervised reconstruction tasks with various data augmentation and a regularized contrastive loss. The weights of the pre-trained backbone could be used to enhance the performance of the novel downstream deep learning tasks.
-
-The [tutorial](https://github.com/Project-MONAI/tutorials/tree/master/self_supervised_pretraining) shows how to generate a good set of pre-trained weights using unlabeled data with self-supervised tasks, then use the pre-trained weights to perform fine-tuning on a fully supervised volumetric segmentation task using a transformer based `UNETR`.
-
-![self-supervised](../images/ssl_overview.png)
-
-### 6. Swin UNETR model for the task of multi-organ segmentation
-For [Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images](https://arxiv.org/abs/2201.01266), MONAI introduces new network modules for multi-organ segmentation task using the BTCV challenge dataset. The architecture of Swin UNETR:
-
-![swin-unetr](../images/swin_unetr.png)
-
-The [tutorial](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb) shows a typical pipeline of multi-organ segmentation based on Swin UNETR model, DiceCE loss function, Mean Dice, etc. And we used weights from self-supervised pre-training of Swin UNETR encoder (3D Swin Transformer) on a cohort of 5050 CT scans from publicly available datasets.
-
-## Performance optimization and GPU acceleration
-Typically, model training is a time-consuming step during deep learning development, especially in medical imaging applications. Volumetric medical images are usually large (as multi-dimensional arrays) and the model training process can be complex. Even with powerful hardware (e.g. CPU/GPU with large RAM), it is not easy to fully leverage them to achieve high performance. MONAI provides a fast training guide to achieve the best performance: https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md.
-
-NVIDIA GPUs have been widely applied in many areas of deep learning training and evaluation, and the CUDA parallel computation shows obvious acceleration when comparing to traditional computation methods. To fully leverage GPU features, many popular mechanisms raised, like automatic mixed precision (AMP), distributed data parallel, etc. MONAI can support these features and provides rich examples.
-
-### 1. Profiling the pipelines
-First of all, MONAI provides several methods based on `DLProf`, `Nsight`, `NVTX` and `NVML` for users to analyze their programs to identify the performance bottleneck. The analyses include operation-based GPU activity and overall GPU activity during model training. They will greatly help users manage computing bottlenecks and provide insights for the area to be improved for better computing efficiency. The detailed example is shown in the performance profiling tutorial: https://github.com/Project-MONAI/tutorials/blob/master/performance_profiling/radiology/profiling_train_base_nvtx.md.
-
-### 2. Auto mixed precision(AMP)
-In 2017, NVIDIA researchers developed a methodology for mixed-precision training, which combined single-precision (FP32) with half-precision (e.g. FP16) format when training a network, and it achieved the same accuracy as FP32 training using the same hyperparameters.
-
-For the PyTorch 1.6 release, developers at NVIDIA and Facebook moved mixed precision functionality into PyTorch core as the AMP package, `torch.cuda.amp`.
-
-MONAI workflows can easily set `amp=True/False` in `SupervisedTrainer` or `SupervisedEvaluator` during training or evaluation to enable/disable AMP. And we tried to compare the training speed if AMP ON/OFF on NVIDIA V100 GPU with CUDA 11 and PyTorch 1.6, obtained some benchmark results:
-![amp v100 results](../images/amp_training_v100.png)
-We also executed the same test program on NVIDIA A100 GPU with the same software environment, obtained faster results:
-![amp a100 results](../images/amp_training_a100.png)
-More details is available at [AMP training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/automatic_mixed_precision.ipynb).
-We also tried to combine `AMP` with `CacheDataset`, `GPU cache`, `GPU transforms`, `ThreadDataLoader`, `DiceCE` loss, tuning of network and optimizer, to achieve the fast training in MONAI, with a V100 GPU and the target validation mean dice = 0.94 of the foreground channel only, it's more than `100x` speedup compared with the Pytorch regular implementation when achieving the same metric. And every epoch is 20x faster than regular training. Benchmark for reference:
-![fast training results](../images/fast_training.png)
-More details is available at [Fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb).
-
-### 3. Distributed data parallel
-Distributed data parallel is an important feature of PyTorch to connect multiple GPU devices on single or multiple nodes to train or evaluate models. The distributed data parallel APIs of MONAI are compatible with native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform. MONAI provides demos for reference: train/evaluate with PyTorch DDP, train/evaluate with Horovod, train/evaluate with Ignite DDP, partition dataset and train with SmartCacheDataset, as well as a real world training example based on Decathlon challenge Task01 - Brain Tumor segmentation. The [tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/distributed_training/brats_training_ddp.py) contains distributed caching, training, and validation. We obtained performance benchmarks for reference (based on PyTorch 1.9.1, CUDA 11.4, NVIDIA V100 GPUs. The `optimization` means that with more GPU resources, we can split the data and cache into GPU memory and execute GPU transforms directly):
-
-![distributed training results](../images/brats_distributed.png)
-
-### 4. C++/CUDA optimized modules
-To further accelerate the domain-specific routines in the workflows, MONAI C++/CUDA implementation are introduced as extensions of the PyTorch native implementations.
-MONAI provides the modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions):
-- via `setuptools`, for modules including `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`.
-- via just-in-time (JIT) compilation, for the `Gaussian mixtures` module. This approach allows for dynamic optimisation according to the user-specified parameters and local system environments.
-The following figure shows results of MONAI's Gaussian mixture models applied to tissue and surgical tools segmentation:
-![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png)
-
-### 5. Cache IO and transforms data to GPU memory
-Even with `CacheDataset`, we usually need to copy the same data to GPU memory for GPU random transforms or network computation in every epoch. An efficient approach is to cache the data to GPU memory directly, then every epoch can start from GPU computation immediately.
-
-For example:
-```py
-train_transforms = [
- LoadImaged(...),
- EnsureChannelFirstd(...),
- Orientationd(...),
- Spacingd(...),
- ScaleIntensityRanged(...),
- EnsureTyped(..., data_type="tensor"),
- ToDeviced(..., device="cuda:0"),
- RandCropByPosNegLabeld(...),
-]
-dataset = CacheDataset(..., transform=train_trans)
-```
-Here we convert to PyTorch `Tensor` with `EnsureTyped` transform and move data to GPU with `ToDeviced` transform. `CacheDataset` caches the transform results until `ToDeviced`, so it is in GPU memory. Then in every epoch, the program fetches cached data from GPU memory and only executes the random transform `RandCropByPosNegLabeld` on GPU directly.
-GPU caching example is available at [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb).
-
-## Applications
-The research area of medical image deep learning is expanding fast. To apply the latest achievements into applications, MONAI contains many application components to build end-to-end solutions or prototypes for other similar use cases.
-
-### 1. DeepGrow modules for interactive segmentation
-[A reimplementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/deepgrow) of the DeepGrow components, which is deep learning based semi-automated segmentation approach that aims to be a "smart" interactive tool for region of interest delineation in medical images, originally proposed by:
-
-Sakinis, Tomas, et al. "Interactive segmentation of medical images through fully convolutional neural networks." arXiv preprint arXiv:1903.08205 (2019).
-
-![deepgrow scheme](../images/deepgrow.png)
-
-### 2. DeepEdit workflow for interactive segmentation
-DeepEdit is a method that combines an automatic and a semi-automatic approach for 3D medical images into a single deep learning-based model. The [implementation](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/deepedit) of the DeepEdit modules provides essential components for interactive segmentation. More details are available in the training and inference [tutorial](https://github.com/Project-MONAI/tutorials/tree/main/deepedit/ignite).
-
-The following figure shows the typical workflow of interactive segmentation:
-
-![deepedit workflow](../images/deepedit.png)
-
-### 3. NuClick modules for interactive nuclei segmentation
-NuClick is a CNN-based approach to speed up collecting annotations for microscopic objects requiring minimum interaction from the annotator. The [implementation](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/nuclick) contains essential components for the training and inference workflows of NuClick interactive nuclei segmentation.
-
-The following figure is example outputs of NuClick (annotator click inside the nucleus and the mask will be generated by CNN):
-
-![nuclick output](../images/nuclick.png)
-
-### 4. Lesion detection in digital pathology
-[Implementation](https://github.com/Project-MONAI/MONAI/tree/master/monai/apps/pathology) of the pathology detection components, which includes efficient whole slide imaging IO and several patch sampling methods with NVIDIA cuCIM library and SmartCache mechanism, FROC measurements for lesion and probabilistic post-processing for lesion detection.
-
-![digital pathology](../images/pathology.png)
-
-### 5. Learning-based image registration
-Starting from v0.5.0, MONAI provides experimental features for building learning-based 2D/3D registration workflows. These include image similarity measures as loss functions, bending energy as model regularization, network architectures, warping modules. The components can be used to build the major unsupervised and weakly-supervised algorithms.
-
-The following figure shows the registration of CT images acquired at different time points for a single patient using MONAI:
-
-![3d registration](../images/3d_paired.png)
-
-### 6. 2D and 3D detection workflow
-The [implementation](https://github.com/Project-MONAI/MONAI/tree/dev/monai/apps/detection) contains 2D and 3D bounding box detection components of `RetinaNet`, which includes:bounding box operations, hard negative sampler, and RetinaNet detectors.
-
-The following figure shows the detection training and inference workflows:
-
-![detection workflow](../images/detection.png)
-
-### 7. Reproducing the state-of-the-art Kaggle competition solutions
-[A reimplementation](https://github.com/Project-MONAI/tutorials/tree/master/kaggle/RANZCR/4th_place_solution) of the 4th place solution of RANZCR CLiP - Catheter and Line Position Challenge in Kaggle: https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification
-
-The original solution is produced by Team Watercooled, and the authors are Dieter (https://www.kaggle.com/christofhenkel) and Psi (https://www.kaggle.com/philippsinger).
diff --git a/docs/source/highlights.rst b/docs/source/highlights.rst
new file mode 100644
index 00000000000..29682a05fdb
--- /dev/null
+++ b/docs/source/highlights.rst
@@ -0,0 +1,10 @@
+:github_url: https://github.com/Project-MONAI/MONAI
+
+Highlights
+==========
+
+.. toctree::
+ :maxdepth: 1
+
+ modules.md
+ applications.md
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 6bd8097ed77..9f8c3cb7eca 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -12,19 +12,17 @@ Project MONAI
*Medical Open Network for AI*
MONAI is a `PyTorch `_-based, `open-source `_ framework
-for deep learning in healthcare imaging, part of `PyTorch Ecosystem `_.
+for deep learning in healthcare imaging, part of the `PyTorch Ecosystem `_.
Its ambitions are:
- developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
- creating state-of-the-art, end-to-end training workflows for healthcare imaging;
-- providing researchers with the optimized and standardized way to create and evaluate deep learning models.
+- providing researchers with an optimized and standardized way to create and evaluate deep learning models.
Features
--------
-*The codebase is currently under active development*
-
- flexible pre-processing for multi-dimensional medical imaging data;
- compositional & portable APIs for ease of integration in existing workflows;
- domain-specific implementations for networks, losses, evaluation metrics and more;
@@ -72,6 +70,13 @@ Technical documentation is available at `docs.monai.io `_
bundle_intro
+Model Zoo
+---------
+
+`The MONAI Model Zoo `_ is a place for researchers and data scientists to share the latest and great models from the community.
+Utilizing `the MONAI Bundle format `_ makes it easy to `get started `_ building workflows with MONAI.
+
+
Links
-----
diff --git a/docs/source/installation.md b/docs/source/installation.md
index c87234c7c7b..ca5ed9a9383 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -4,6 +4,7 @@
1. [From PyPI](#from-pypi)
1. [Milestone release](#milestone-release)
2. [Weekly preview release](#weekly-preview-release)
+ 3. [Uninstall the packages](#uninstall-the-packages)
1. [From conda-forge](#from-conda-forge)
2. [From GitHub](#from-github)
1. [System-wide](#milestone-release)
@@ -48,6 +49,19 @@ To report any issues on the weekly preview, please include the version and commi
python -c "import monai; print(monai.__version__); print(monai.__commit_id__)"
```
+Coexistence of package `monai` and `monai-weekly` in a system may cause namespace conflicts
+and `ImportError`.
+This is usually a result of running both `pip install monai` and `pip install monai-weekly`
+without uninstalling the existing one first.
+To address this issue, please uninstall both packages, and retry the installation.
+
+### Uninstall the packages
+The packages installed using `pip install` could be removed by:
+```bash
+pip uninstall -y monai
+pip uninstall -y monai-weekly
+```
+
## From conda-forge
To install the [current milestone release](https://anaconda.org/conda-forge/monai):
@@ -125,7 +139,7 @@ and the codebase is ready to use (without the additional features of MONAI C++/C
## Validating the install
You can verify the installation by:
```bash
-python -c 'import monai; monai.config.print_config()'
+python -c "import monai; monai.config.print_config()"
```
If the installation is successful, this command will print out the MONAI version information, and this confirms the core
modules of MONAI are ready-to-use.
@@ -196,9 +210,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are
```
-[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, pynrrd, pydicom, h5py]
+[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna]
```
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `pynrrd`, `pydicom`, h5py , respectively.
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, respectively.
- `pip install 'monai[all]'` installs all the optional dependencies.
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index aea3f1789a7..d8da890276c 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -17,6 +17,20 @@ Metrics
.. autoclass:: Metric
:members:
+`Variance`
+--------------
+.. autofunction:: compute_variance
+
+.. autoclass:: VarianceMetric
+ :members:
+
+`LabelQualityScore`
+--------------------
+.. autofunction:: label_quality_score
+
+.. autoclass:: LabelQualityScore
+ :members:
+
`IterationMetric`
-----------------
.. autoclass:: IterationMetric
@@ -90,6 +104,13 @@ Metrics
.. autoclass:: SurfaceDiceMetric
:members:
+`PanopticQualityMetric`
+-----------------------
+.. autofunction:: compute_panoptic_quality
+
+.. autoclass:: PanopticQualityMetric
+ :members:
+
`Mean squared error`
--------------------
.. autoclass:: MSEMetric
diff --git a/docs/source/modules.md b/docs/source/modules.md
new file mode 100644
index 00000000000..0063466b807
--- /dev/null
+++ b/docs/source/modules.md
@@ -0,0 +1,332 @@
+# Modules
+
+MONAI aims at facilitating deep learning in medical image analysis at multiple granularities. This document provides an
+overview of the modules and highlights the key capabilities.
+
+The core codebase is designed as a library of lightweight, flexible, and comprehensive APIs for users with varying expertise.
+The building blocks are made easy to understand and use, they are carefully decoupled and can be readily integrated
+into existing PyTorch programs and larger systems. By leveraging the workflow and bundle APIs, users can also quickly
+set up efficient and robust model training or evaluation pipelines for various domain-specific applications.
+
+The overall architecture and modules are shown in the following figure:
+
+![architecture overview](../images/arch_modules.png)
+
+* [I/O, processing and augmentation](#i-o-processing-and-augmentation)
+* [Datasets and Data Loading](#datasets-and-data-loading)
+* [Differentiable components, networks, losses and optimizers](#differentiable-components-networks-losses-and-optimizers)
+* [Evaluation](#evaluation)
+* [Visualization](#visualization)
+* [Workflows](#workflows)
+* [Bundle](#bundle)
+* [Federated Learning](#federated-learning)
+* [Auto3dseg](#auto3dseg)
+* [GPU acceleration, performance profiling and optimization](#gpu-acceleration-performance-profiling-and-optimization)
+
+## I/O, processing and augmentation
+Medical images require specialized methods for I/O, preprocessing and augmentation. They often follow specific formats,
+are handled with specific protocols, and the data arrays are often high-dimensional.
+[`monai.transforms`](https://github.com/Project-MONAI/MONAI/tree/dev/monai/transforms) and
+[`monai.data`](https://github.com/Project-MONAI/MONAI/tree/dev/monai/data) modules include a set of domain-specific APIs
+for various deep learning applications:
+
+### Transforms with data in array and dictionary styles
+
+![3d transform examples](../images/affine.png)
+
+This enables basic image transformations, as well as more complex preprocessing pipelines such as synchronized operations
+across different modalities and model supervision inputs. [[array and dict examples]](https://github.com/Project-MONAI/tutorials/tree/main/3d_segmentation/torch)
+
+### Various image patch-based sampling mechanisms
+
+![2d transform examples](../images/medical_transforms.png)
+
+Advanced patch sampling methods are implemented for selective preprocessing, such as weighted, class-balanced sampling
+from user-specified sampling weight maps.
+The output can be in a sequence or iterator pattern which allows for different types of shuffling strategies.
+
+### Image IO with third-party library integrations
+
+Several backends are built-in and can support various formats. It is easily extensible for customized format readers.
+
+### monai.data.MetaTensor
+
+Core data structure combines PyTorch native Tensor APIs with metadata handling,
+so that the deep learning models and pipelines can readily incorporate the meta information. [[MetaTensor]](https://colab.research.google.com/drive/1T4iAys-cC2qL80oJkIbAXAPlWNPwp4H7)
+
+### GPU-based accelerations
+
+Implementations are provided to ensure optimal usage of the underlying hardware resources. [[fast training guide]](https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_model_training_guide.md)
+
+### Determinism and reproducibility
+
+They can be achieved with fine-level of local controls via the `Randomizable` API as well as globally
+using `set_determinism`.
+
+### Decollating and invertible transforms
+
+![invert transform](../images/invert_transforms.png)
+The mini-batch data output from a model can be decollated, post-processed independently, including inverting
+the outputs to an earlier step of the preprocessing according to the tracked metadata and applied operations.
+[[inverse transform demo]](https://github.com/Project-MONAI/tutorials/blob/main/modules/inverse_transforms_and_test_time_augmentations.ipynb)
+
+### Enhanced usability
+
+Additionally, utilities such as `DataStats` transform, `dev_collate`, and [visualization
+methods](https://github.com/Project-MONAI/tutorials/blob/main/modules/transform_visualization.ipynb) are provided as
+extensions to PyTorch for improved overall debugability.
+
+## Datasets and Data Loading
+Following PyTorch's design pattern, MONAI extends the `Dataset` and `DataLoader` APIs as major enhancements in terms of
+domain-specific usability and pipeline performance.
+
+### Cache IO and transforms data to accelerate training
+
+Data-driven methods require many (potentially thousands of) epochs of training data reading and preprocessing. MONAI
+provides multi-threaded cache-based datasets to accelerate the process [[Datasets experiment]](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb). The
+cache can be persistent and dynamic (`SmartCacheDataset`) and reused across different experiments [[SmartCache example]](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/distributed_training/unet_training_smartcache.py).
+The following figure illustrates the training speedup compared with a regular PyTorch program.
+
+![cachedataset speed](../images/datasets_speed.png)
+
+### `ThreadDataLoader` vs. `DataLoader`
+
+If the transforms are light-weighted, especially when we cache all the data in RAM, the multiprocessing of PyTorch
+`DataLoader` may cause unnecessary IPC time and decrease GPU utilization. MONAI provides `ThreadDataLoader` which
+executes the transforms in a separate thread:
+
+![threaddataloader](../images/threaddataloader.png)
+
+a `ThreadDataLoader` example is within the [Spleen fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb).
+
+### Public datasets
+
+To quickly get started with popular training data, MONAI provides several ready-to-integrate Dataset classes
+(such as `MedNISTDataset`, `DecathlonDataset`, [`TciaDataset`](https://github.com/Project-MONAI/tutorials/blob/main/modules/tcia_dataset.ipynb)), which include data downloading, and support training/evaluation splits generation with transforms.
+[[Public datasets tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/modules/public_datasets.ipynb)
+The common workflow of predefined datasets:
+
+![pre-defined dataset](../images/dataset_progress.png)
+
+### Dataset type extensions
+
+Other extensions of the `Dataset` API include: `ZipDataset` for associating multiple data sources, `PatchDataset` for
+handling both image- and patch-level preprocessing, `CSVDataset` for multi-modal inputs, and `partition_dataset` for
+cross-validation data preparations.
+
+## Differentiable components, networks, losses and optimizers
+
+Some deep neural network architectures have shown to be particularly effective for medical imaging analysis tasks.
+MONAI implements reference networks with the aim of both flexibility and code readability.
+
+### Predefined layers and blocks
+
+Network layers and blocks are in general implemented to be compatible with spatial 1D, 2D and 3D inputs.
+Users can easily integrate the layers, blocks and networks as part of their customised pipelines.
+Various utilities are provided to leverage the existing model weights, e.g., finetuning [from MMAR](https://github.com/Project-MONAI/tutorials/blob/master/modules/transfer_mmar.ipynb)
+or [from a bundle in MONAI model-zoo](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo).
+
+### C++/CUDA optimized modules
+
+To further accelerate the domain-specific routines, MONAI C++/CUDA implementation is introduced as extensions of the PyTorch native implementations.
+MONAI provides the modules using [the two ways of building C++ extensions from PyTorch](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions):
+- via `setuptools`, for modules including `Resampler`, `Conditional random field (CRF)`, `Fast bilateral filtering using the permutohedral lattice`.
+- via just-in-time (JIT) compilation, for the `Gaussian mixtures` module. This approach allows for dynamic optimisation according to the user-specified parameters and local system environments.
+The following figure shows results of MONAI's Gaussian mixture models applied to tissue and surgical tools segmentation:
+![Gaussian mixture models as a postprocessing step](../images/gmm_feature_set_comparison_s.png)
+
+
+### Losses and optimizers
+
+Commonly used loss functions for various applications are (re-)implemented from the literature, such as `DiceLoss`, `GeneralizedDiceLoss`, `TverskyLoss`, `DiceFocalLoss`.
+The numerical optimizations and relevant utilities include `Novograd` and `LearningRateFinder`.
+The following figure shows a learning rate search process.
+
+![learning rate finder plot](../images/lr_finder.png)
+
+## Evaluation
+To run model inferences and evaluate the model quality, MONAI provides reference implementations for the relevant
+widely-used approaches. Currently, several popular evaluation metrics and inference patterns are included:
+
+### Sliding window inference
+
+For model inferences on large volumes, the sliding window approach is a popular choice to achieve high performance while
+having flexible memory requirements (_alternatively, please check out the latest research on [model parallel
+training](#lamp-large-deep-nets-with-automated-model-parallelism-for-image-segmentation) using MONAI_). It also supports
+`overlap` and `blending_mode` configurations to handle the overlapped windows for better performances.
+
+![sliding window scheme](../images/sliding_window.png)
+
+### Metrics for medical tasks
+
+Various useful evaluation metrics have been implemented to measure the quality of medical image specific models.
+These include `Mean Dice`, `ROCAUC`, `Confusion Matrices`, `Hausdorff
+Distance`, `Surface Distance`, `Occlusion Sensitivity`.
+The APIs also support [multi-processing computation](https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py).
+
+### Report generation
+`MetricsSaver` is provided to write the final metric summary report: `mean`, `median`, `max`, `min`, `percentile`, `std`:
+
+![metrics report example](../images/metrics_report.png)
+
+## Visualization
+Beyond the simple point and curve plotting, intuitive interfaces are provided to visualize multidimensional data as GIF animations in TensorBoard. This could provide a quick qualitative assessment of the model by visualizing, for example, the volumetric inputs, segmentation maps, and intermediate feature maps. A runnable example with visualization is available at [UNet training example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py). To work with ignite program, MONAI also provides several ignite handlers to visualize training curve and metrics with `TensorBoard` or `MLFlow`, more details is available in [TensorBoard and MLFlow handlers example](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb).
+
+To easily visualize a 3D image as frames of 2D images, MONAI provides the utility `matshow3d` based on `matplotlib` library. It can plot frames of image for the specified dimension, showing a spleen 3D image as example:
+`matshow3d(volume=image, figsize=(100, 100), every_n=10, frame_dim=-1 show=True, cmap="gray")`
+
+![matshow3d example](../images/matshow3d.png)
+
+MONAI also provides the `blend_images` utility to blend the `image` and `label` to an RGB color image to better visualize the segmentation regions with the specified `cmap` mode and weights, etc. Showing a spleen segmentation `image` and the corresponding `label` as example:
+
+![blend example](../images/blend.png)
+
+For more details of `TensorBoard utility`, `matshow3d` and `blend_images`, please check the [visualization tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/transform_visualization.ipynb).
+
+And to visualize the class activation mapping for a trained classification model, MONAI provides CAM, GradCAM, GradCAM++ APIs for both 2D and 3D models:
+
+![CAM visualization example](../images/cam.png)
+
+The above example is generated by computing [GradCAM/GradCAM++ from a lung CT lesion classification model](https://github.com/Project-MONAI/tutorials/tree/master/modules/interpretability).
+
+## Workflows
+
+MONAI engines and workflows enable quick start of training and evaluation experiments.
+
+These features decouple the domain-specific components and the generic machine learning processes.
+They also provide a set of unify APIs for higher level applications (such as AutOML, Federated Learning).
+The trainers and evaluators of the workflows are compatible with pytorch-ignite `Engine` and `Event-Handler` mechanism.
+
+### General workflows pipeline
+
+The workflow and some of MONAI event handlers are shown as below [[Workflow examples]](https://github.com/Project-MONAI/tutorials/tree/master/modules/engines):
+
+![workflow pipeline](../images/workflows.png)
+
+
+### EnsembleEvaluator
+
+A typical ensemble procoess is implemented as a ready-to-use workflow [[Cross validation and model ensemble tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/modules/cross_validation_models_ensemble.ipynb):
+1. Split all the training dataset into K folds.
+2. Train K models with every K-1 folds data.
+3. Execute inference on the test data with all the K models.
+4. Compute the average values with weights or vote the most common value as the final result.
+
+![model ensemble](../images/models_ensemble.png)
+
+
+### Decollate batch data for flexible post-processings
+
+`decollate batch` is introduced since MONAI v0.6, which simplifies the post-processing transforms and provides flexible following operations on a batch of data with various data shapes. It can decollate batched data (e.g. model predictions) into a list of tensors, for the benefits such as:
+1. enabling postprocessing transforms for each item independently -- randomised transforms could be applied differently for each predicted item in a batch.
+2. simplifying the transform APIs and reducing the input validation burdens because both the preprocessing and postprocessing transforms now only need to support the "channel-first" input format.
+3. enabling the `Invertd` transform for the predictions and the inverted data with different shapes, as the data items are in a list, not stacked in a single tensor.
+4. allowing for both batch-first tensor and list of channel-first tensors in a flexible metric computation. [[decollate batch tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/modules/decollate_batch.ipynb)
+
+A typical process of `decollate batch` is illustrated as follows (with a `batch_size=N` model predictions and labels as an example):
+
+![decollate_batch](../images/decollate_batch.png)
+
+### Easy to integrate into popular workflows
+
+Except for the pytorch-ignite based `monai.engines`, most of the MONAI modules could be used independently or combined
+with other software packages. For example, MONAI can be easily integrated into popular frameworks such as
+PyTorch-Lightning and Catalyst. [[Lightning segmentation](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d_lightning.ipynb),
+[Catalyst segmentation](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_catalyst.ipynb)]
+
+## Bundle
+
+The objective of a MONAI bundle is to define a packaged model which includes the critical information necessary to allow
+users and programs to understand how the model is used and for what purpose. A bundle includes the stored weights of a
+single network as a pickled state dictionary plus optionally a Torchscript object and/or an ONNX object. Additional JSON
+files are included to store metadata about the model, information for constructing training, inference, and
+post-processing transform sequences, plain-text description, legal information, and other data the model creator wishes
+to include. More details are available at [bundle specification](https://docs.monai.io/en/latest/mb_specification.html).
+
+The key benefits of bundle are to define the model package and support building Python-based workflows via structured configurations:
+- Self-contained model package include all the necessary information.
+- Structured config can be used to easily reconstruct or prototype deep learning workflows.
+- Config files can provide good readability and usability by separating parameter settings from the Python code.
+- Config files can describe flexible workflow and components, allows for different low-level Python implementations
+- Learning paradigms at a higher level such as federated learning and AutoML can be decoupled from the component details.
+
+A typical bundle example can include:
+```
+ ModelName
+ ┣━ configs
+ ┃ ┗━ metadata.json
+ ┣━ models
+ ┃ ┣━ model.pt
+ ┃ ┣━ *model.ts
+ ┃ ┗━ *model.onnx
+ ┗━ docs
+ ┣━ *README.md
+ ┗━ *license.txt
+```
+Details about the bundle config definition and syntax & examples are at [config syntax](https://docs.monai.io/en/latest/config_syntax.html).
+A step-by-step [get started](https://github.com/Project-MONAI/tutorials/blob/master/modules/bundles/get_started.ipynb) tutorial notebook can help users quickly set up a bundle. [[bundle examples](https://github.com/Project-MONAI/tutorials/tree/main/bundle), [model-zoo](https://github.com/Project-MONAI/model-zoo)]
+
+## Federated Learning
+
+![federated-learning](../images/federated.svg)
+
+Using the MONAI bundle configurations, we can use MONAI's [`MonaiAlgo`](https://docs.monai.io/en/latest/fl.html#monai.fl.client.MonaiAlgo)
+class, an implementation of the abstract [`ClientAlgo`](https://docs.monai.io/en/latest/fl.html#clientalgo) class for federated learning (FL),
+to execute bundles from the [MONAI model zoo](https://github.com/Project-MONAI/model-zoo).
+Note that [`ClientAlgo`](https://docs.monai.io/en/latest/fl.html#clientalgo) is provided as an abstract base class for
+defining an algorithm to be run on any federated learning platform.
+[`MonaiAlgo`](https://docs.monai.io/en/latest/fl.html#monai.fl.client.MonaiAlgo) implements the main functionalities needed
+to run federated learning experiments, namely `train()`, `get_weights()`, and `evaluate()`, that can be run using single- or multi-GPU training.
+On top, it provides implementations for life-cycle management of the component such as `initialize()`, `abort()`, and `finalize()`.
+The MONAI FL client also allows computing summary data statistics (e.g., intensity histograms) on the datasets defined in the bundle configs
+using the [`MonaiAlgoStats`](https://docs.monai.io/en/latest/fl.html#monai.fl.client.MonaiAlgoStats) class.
+These statistics can be shared and visualized on the FL server.
+[NVIDIA FLARE](https://github.com/NVIDIA/NVFlare), the federated learning platform developed by NVIDIA, has already built [the integration piece](https://github.com/NVIDIA/NVFlare/tree/2.2/integration/monai)
+with [`ClientAlgo`](https://docs.monai.io/en/latest/fl.html#clientalgo) to allow easy experimentation with MONAI bundles within their federated environment.
+Our [[federated learning tutorials]](https://github.com/Project-MONAI/tutorials/tree/main/federated_learning/nvflare) shows
+examples of single- & multi-GPU training and federated statistics workflows.
+
+## Auto3dseg
+
+![auto3dseg](../images/auto3dseg.png)
+
+[Auto3DSeg](https://monai.io/apps/auto3dseg.html) is a comprehensive solution for large-scale 3D medical image segmentation.
+It leverages the latest advances in MONAI
+and GPUs to efficiently develop and deploy algorithms with state-of-the-art performance.
+It first analyzes the global information such as intensity, dimensionality, and resolution of the dataset,
+then generates algorithms in MONAI bundle format based on data statistics and [algorithm templates](https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg).
+Next, all algorithms initiate model training to obtain checkpoints with the best validation performance.
+Finally, the ensemble module selects the algorithms via ranking trained checkpoints and creates ensemble predictions.
+
+The solution offers different levels of user experience for beginners and advanced researchers.
+It has been tested on large-scale 3D medical imaging datasets in different modalities.
+
+## GPU acceleration, performance profiling and optimization
+
+MONAI provides state-of-the-art performance optimization methods including:
+
+### Auto mixed precision (AMP)
+
+Simply set `amp=True/False` in `SupervisedTrainer` or `SupervisedEvaluator` during training or evaluation to enable/disable AMP
+Example benchmark results are as follows [[AMP training tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/automatic_mixed_precision.ipynb):
+
+training with AMP ON/OFF on a NVIDIA V100 GPU with CUDA 11 and PyTorch 1.6:
+
+![amp v100 results](../images/amp_training_v100.png)
+
+training with AMP ON/OFF on a NVIDIA A100 GPU with CUDA 11 and PyTorch 1.6:
+
+![amp a100 results](../images/amp_training_a100.png)
+
+Several tools including `DLProf`, `Nsight`, `NVTX` and `NVML` can be used with MONAI to identify the performance bottleneck. [[profiling tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/performance_profiling/radiology/profiling_train_base_nvtx.md)
+
+### Distributed training
+
+The distributed data-parallel APIs of MONAI are compatible with the native PyTorch distributed module, pytorch-ignite distributed module, Horovod, XLA, and the SLURM platform.
+[[distributed training tutorial]](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/distributed_training/brats_training_ddp.py)
+
+![distributed training results](../images/brats_distributed.png)
+
+The [fast training tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_training_tutorial.ipynb)
+combines `AMP` with `CacheDataset`, `GPU cache`, `GPU transforms`, `ThreadDataLoader`, tuning of networks and optimizers, can achieve substantial speedup compared
+with a PyTorch regular implementation.
diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 1b89f2a3298..a4c225de292 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -89,6 +89,11 @@ Blocks
.. autoclass:: UnetOutBlock
:members:
+`DenseBlock`
+~~~~~~~~~~~~~
+.. autoclass:: DenseBlock
+ :members:
+
`SegResnet Block`
~~~~~~~~~~~~~~~~~
.. autoclass:: ResBlock
@@ -233,6 +238,10 @@ Blocks
.. autoclass:: DVF2DDF
:members:
+`VarNetBlock`
+~~~~~~~~~~~~~
+.. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock
+ :members:
N-Dim Fourier Transform
~~~~~~~~~~~~~~~~~~~~~~~~
@@ -325,6 +334,16 @@ Layers
.. autoclass:: GaussianFilter
:members:
+`MedianFilter`
+~~~~~~~~~~~~~~
+.. autoclass:: MedianFilter
+ :members:
+
+`median_filter`
+~~~~~~~~~~~~~~~
+.. autoclass:: median_filter
+ :members:
+
`BilateralFilter`
~~~~~~~~~~~~~~~~~
.. autoclass:: BilateralFilter
@@ -519,6 +538,18 @@ Nets
.. autoclass:: BasicUnet
.. autoclass:: Basicunet
+`BasicUNetPlusPlus`
+~~~~~~~~~~~~~~~~~~~
+.. autoclass:: BasicUNetPlusPlus
+ :members:
+.. autoclass:: BasicUnetPlusPlus
+.. autoclass:: BasicunetPlusPlus
+
+`FlexibleUNet`
+~~~~~~~~~~~~~~
+.. autoclass:: FlexibleUNet
+ :members:
+
`VNet`
~~~~~~
.. autoclass:: VNet
@@ -644,6 +675,11 @@ Nets
.. autoclass:: monai.apps.reconstruction.networks.nets.coil_sensitivity_model.CoilSensitivityModel
:members:
+`e2e-VarNet`
+~~~~~~~~~~~~
+.. autoclass:: monai.apps.reconstruction.networks.nets.varnet.VariationalNetworkModel
+ :members:
+
Utilities
---------
.. automodule:: monai.networks.utils
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index f8dc318e095..7b728fde487 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -22,11 +22,31 @@ Generic Interfaces
:members:
:special-members: __call__
+`RandomizableTrait`
+^^^^^^^^^^^^^^^^^^^
+.. autoclass:: RandomizableTrait
+ :members:
+
+`LazyTrait`
+^^^^^^^^^^^
+.. autoclass:: LazyTrait
+ :members:
+
+`MultiSampleTrait`
+^^^^^^^^^^^^^^^^^^
+.. autoclass:: MultiSampleTrait
+ :members:
+
`Randomizable`
^^^^^^^^^^^^^^
.. autoclass:: Randomizable
:members:
+`LazyTransform`
+^^^^^^^^^^^^^^^
+.. autoclass:: LazyTransform
+ :members:
+
`RandomizableTransform`
^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: RandomizableTransform
@@ -336,6 +356,14 @@ Intensity
:members:
:special-members: __call__
+`MedianSmooth`
+""""""""""""""
+.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/MedianSmooth.png
+ :alt: example of MedianSmooth
+.. autoclass:: MedianSmooth
+ :members:
+ :special-members: __call__
+
`GaussianSmooth`
""""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSmooth.png
@@ -461,6 +489,13 @@ Intensity
:members:
:special-members: __call__
+`ComputeHoVerMaps`
+""""""""""""""""""
+.. autoclass:: ComputeHoVerMaps
+ :members:
+ :special-members: __call__
+
+
IO
^^
@@ -530,6 +565,14 @@ Post-processing
:members:
:special-members: __call__
+`RemoveSmallObjects`
+""""""""""""""""""""
+.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RemoveSmallObjects.png
+ :alt: example of RemoveSmallObjects
+.. autoclass:: RemoveSmallObjects
+ :members:
+ :special-members: __call__
+
`LabelFilter`
"""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelFilter.png
@@ -563,12 +606,87 @@ Post-processing
.. autoclass:: ProbNMS
:members:
+`SobelGradients`
+""""""""""""""""
+.. autoclass:: SobelGradients
+ :members:
+ :special-members: __call__
+
`VoteEnsemble`
""""""""""""""
.. autoclass:: VoteEnsemble
:members:
:special-members: __call__
+Signal
+^^^^^^^
+
+`SignalRandDrop`
+""""""""""""""""
+.. autoclass:: SignalRandDrop
+ :members:
+ :special-members: __call__
+
+`SignalRandScale`
+"""""""""""""""""
+.. autoclass:: SignalRandScale
+ :members:
+ :special-members: __call__
+
+`SignalRandShift`
+"""""""""""""""""
+.. autoclass:: SignalRandShift
+ :members:
+ :special-members: __call__
+
+`SignalRandAddSine`
+"""""""""""""""""""
+.. autoclass:: SignalRandAddSine
+ :members:
+ :special-members: __call__
+
+`SignalRandAddSquarePulse`
+""""""""""""""""""""""""""
+.. autoclass:: SignalRandAddSquarePulse
+ :members:
+ :special-members: __call__
+
+`SignalRandAddGaussianNoise`
+""""""""""""""""""""""""""""
+.. autoclass:: SignalRandAddGaussianNoise
+ :members:
+ :special-members: __call__
+
+`SignalRandAddSinePartial`
+""""""""""""""""""""""""""
+.. autoclass:: SignalRandAddSinePartial
+ :members:
+ :special-members: __call__
+
+`SignalRandAddSquarePulsePartial`
+"""""""""""""""""""""""""""""""""
+.. autoclass:: SignalRandAddSquarePulsePartial
+ :members:
+ :special-members: __call__
+
+`SignalFillEmpty`
+"""""""""""""""""
+.. autoclass:: SignalFillEmpty
+ :members:
+ :special-members: __call__
+
+`SignalRemoveFrequency`
+"""""""""""""""""""""""
+.. autoclass:: SignalRemoveFrequency
+ :members:
+ :special-members: __call__
+
+`SignalContinuousWavelet`
+"""""""""""""""""""""""""
+.. autoclass:: SignalContinuousWavelet
+ :members:
+ :special-members: __call__
+
Spatial
^^^^^^^
@@ -1325,6 +1443,14 @@ Intensity (Dict)
:members:
:special-members: __call__
+`MedianSmoothd`
+"""""""""""""""
+.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/MedianSmoothd.png
+ :alt: example of MedianSmoothd
+.. autoclass:: MedianSmoothd
+ :members:
+ :special-members: __call__
+
`GaussianSmoothd`
"""""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/GaussianSmoothd.png
@@ -1397,6 +1523,12 @@ Intensity (Dict)
:members:
:special-members: __call__
+`ComputeHoVerMapsd`
+"""""""""""""""""""
+.. autoclass:: ComputeHoVerMapsd
+ :members:
+ :special-members: __call__
+
IO (Dict)
^^^^^^^^^
@@ -1437,6 +1569,14 @@ Post-processing (Dict)
:members:
:special-members: __call__
+`RemoveSmallObjectsd`
+"""""""""""""""""""""
+.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/RemoveSmallObjectsd.png
+ :alt: example of RemoveSmallObjectsd
+.. autoclass:: RemoveSmallObjectsd
+ :members:
+ :special-members: __call__
+
`LabelFilterd`
""""""""""""""
.. image:: https://github.com/Project-MONAI/DocImages/raw/main/transforms/LabelFilterd.png
@@ -1495,6 +1635,14 @@ Post-processing (Dict)
:members:
:special-members: __call__
+
+`SobelGradientsd`
+"""""""""""""""""
+.. autoclass:: SobelGradientsd
+ :members:
+ :special-members: __call__
+
+
Spatial (Dict)
^^^^^^^^^^^^^^
diff --git a/docs/source/visualize.rst b/docs/source/visualize.rst
index 3779feec887..1860b65e03d 100644
--- a/docs/source/visualize.rst
+++ b/docs/source/visualize.rst
@@ -25,7 +25,15 @@ Occlusion sensitivity
.. automodule:: monai.visualize.occlusion_sensitivity
:members:
+Gradient-based saliency maps
+----------------------------
+
+.. automodule:: monai.visualize.gradient_based
+ :members:
+
+
Utilities
---------
+
.. automodule:: monai.visualize.utils
:members:
diff --git a/docs/source/whatsnew.rst b/docs/source/whatsnew.rst
index 05c853b9cd3..bb6665e6216 100644
--- a/docs/source/whatsnew.rst
+++ b/docs/source/whatsnew.rst
@@ -6,6 +6,7 @@ What's New
.. toctree::
:maxdepth: 1
+ whatsnew_1_0.md
whatsnew_0_9.md
whatsnew_0_8.md
whatsnew_0_7.md
diff --git a/docs/source/whatsnew_0_9.md b/docs/source/whatsnew_0_9.md
index fa58630bdce..357dc01b355 100644
--- a/docs/source/whatsnew_0_9.md
+++ b/docs/source/whatsnew_0_9.md
@@ -1,4 +1,4 @@
-# What's new in 0.9 🎉🎉
+# What's new in 0.9
- MONAI Bundle
- Object detection in medical images
@@ -17,7 +17,7 @@ The key benefits of Bundle and the `monai.bundle` APIs are:
- Flexible config components to allow for different low-level Python implementations,
- Help to decouple the component details from higher level learning paradigms such as federated learning and AutoML.
-More details are [in the tutorials](https://github.com/Project-MONAI/tutorials/tree/main/modules/bundle).
+More details are [in the tutorials](https://github.com/Project-MONAI/tutorials/tree/main/bundle).
## Object detection in medical images
This release includes essential components for object localization and categorization workflows.
@@ -57,5 +57,5 @@ especially for the data-driven approaches that MONAI has been focusing.
Starting from this release, we roll out a major refactoring for data representation in MONAI. For the first
step, [the core data structures](https://github.com/Project-MONAI/MONAI/blob/dev/monai/data/meta_tensor.py)
`MetaTensor` and `MetaObj` are implemented as a feature preview.
-Further developments [on the feature branch](https://github.com/Project-MONAI/MONAI/tree/feature/MetaTensor)
+Further developments [on the feature branch](https://github.com/Project-MONAI/MONAI/pull/4539)
will be made available in future milestone releases.
diff --git a/docs/source/whatsnew_1_0.md b/docs/source/whatsnew_1_0.md
new file mode 100644
index 00000000000..36ab393af12
--- /dev/null
+++ b/docs/source/whatsnew_1_0.md
@@ -0,0 +1,67 @@
+# What's new in 1.0 🎉🎉
+
+- Model Zoo
+- Auto3DSeg
+- Federated Learning Client
+- MetaTensor Support for Digital Pathology Workflows
+- Accelerated MRI Reconstruction
+
+
+## Model Zoo
+The MONAI Model Zoo is a place for researchers and data scientists to use and share the latest and great models from the community.
+Utilizing [the MONAI Bundle format](https://github.com/Project-MONAI/tutorials/tree/main/bundle) makes it easy to quickly get started using any model with any MONAI Framework (Core, Label, or Deploy).
+Or, if you're interested in [contributing your models](https://github.com/project-monai/model-zoo), take a look at our contributing guidelines,
+which walks you through the process and requirements for submitting your model.
+For more details about how to use the models, please see [the tutorials](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo).
+
+## Auto3DSeg
+![auto3dseg](../images/auto3dseg.png)
+
+[Auto3DSeg](https://monai.io/apps/auto3dseg.html) is a comprehensive solution for large-scale 3D medical image segmentation.
+It leverages the latest advances in MONAI
+and GPUs to efficiently develop and deploy algorithms with state-of-the-art performance.
+It first analyzes the global information such as intensity, dimensionality, and resolution of the dataset,
+then generates algorithms in MONAI bundle format based on data statistics and [algorithm templates](https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg).
+Next, all algorithms initiate model training to obtain checkpoints with the best validation performance.
+Finally, the ensemble module selects the algorithms via ranking trained checkpoints and creates ensemble predictions.
+
+The solution offers different levels of user experience for beginners and advanced researchers.
+It has been tested on large-scale 3D medical imaging datasets in different modalities.
+
+## Federated Learning Client
+![federated-learning](../images/federated.svg)
+
+MONAI now includes the federated learning (FL) client algorithm APIs that are exposed as an abstract base class
+for defining an algorithm to be run on any federated learning platform.
+[NVIDIA FLARE](https://github.com/NVIDIA/NVFlare), the federated learning platform developed by [NVIDIA](https://www.nvidia.com/en-us/),
+has already built [the integration piece](https://github.com/NVIDIA/NVFlare/tree/dev/integration/monai) with these new APIs.
+With [the new federated learning APIs](https://docs.monai.io/en/latest/fl.html), MONAI bundles can seamlessly be extended to a federated paradigm
+and executed using single- or multi-GPU training.
+The MONAI FL client also allows computing summary data statistics (e.g., intensity histograms) on the datasets defined in the bundle configs.
+These can be shared and visualized on the FL server, for example, using NVIDIA FLARE's federated statistics operators,
+see [here](https://github.com/NVIDIA/NVFlare/tree/dev/integration/monai/examples/spleen_ct_segmentation) for an example.
+
+We welcome other federated learning toolkits to integrate with MONAI FL APIs, building a common foundation for
+collaborative learning in medical imaging.
+
+## MetaTensor Support for Digital Pathology Workflows
+![pathology](../images/pathology-meta.png)
+
+In this release, we support MetaTensor in all digital pathology components, and
+make sure that the future development can benefit from them. With the help of
+MONAI Pathology Working Group, we have standardized a set of metadata
+attributes for patches of images extracted from WSI to ensure reproducibility
+and enhance functionality via relying on a standard set of attributes. The
+figure above shows all the pathology metadata attributes and their relation to
+MetaTensors. Please see [the tutorials and
+examples](https://github.com/Project-MONAI/tutorials/tree/main/pathology).
+
+## Accelerated MRI Reconstruction
+![MRI-reconstruction](../images/mri_recon.png)
+
+This release includes initial components for various popular accelerated MRI reconstruction workflows.
+Many of them are general-purpose tools, for example the [`SSIMLoss`](https://docs.monai.io/en/latest/losses.html?highlight=ssimloss#ssimloss) function.
+Some new functionalities are task-specific, for example [`FastMRIReader`](https://docs.monai.io/en/latest/data.html?highlight=fastmri#monai.apps.reconstruction.fastmri_reader.FastMRIReader).
+
+For more details, please see [this tutorial](https://github.com/Project-MONAI/tutorials/tree/main/reconstruction/MRI_reconstruction/unet_demo) for using a baseline model for this task,
+and [this tutorial](https://github.com/Project-MONAI/tutorials/tree/main/reconstruction/MRI_reconstruction/varnet_demo) for using a state-of-the-art model.
diff --git a/environment-dev.yml b/environment-dev.yml
index 5b0ac2d9220..a11659eff7f 100644
--- a/environment-dev.yml
+++ b/environment-dev.yml
@@ -5,15 +5,15 @@ channels:
- conda-forge
dependencies:
- numpy>=1.17
- - pytorch>=1.6
+ - pytorch>=1.8
- coverage>=5.5
- parameterized
- setuptools>=50.3.0,!=60.0.0
- - ignite==0.4.8
+ - ignite==0.4.10
- gdown>=4.4.0
- scipy
- nibabel
- - pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571
+ - pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571
- tensorboard
- scikit-image>=0.14.2
- tqdm>=4.47.0
@@ -38,7 +38,7 @@ dependencies:
- pandas
- requests
- einops
- - transformers
+ - transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157
- mlflow
- tensorboardX
- pyyaml
@@ -47,6 +47,7 @@ dependencies:
- pynrrd
- pydicom
- h5py
+ - optuna
- pip
- pip:
# pip for itk as conda-forge version only up to v5.1
@@ -59,7 +60,8 @@ dependencies:
# https://github.com/conda/conda/issues/8089
- pytype>=2020.6.1; platform_system != "Windows"
- openslide-python==1.1.2
- - cucim>=21.8.2; platform_system == "Linux"
+ - cucim>=22.8.1; platform_system == "Linux"
- imagecodecs; platform_system == "Linux"
- tifffile; platform_system == "Linux"
- matplotlib!=3.5.0
+ - nni
diff --git a/monai/README.md b/monai/README.md
index 2c30531bf39..a1e36c62100 100644
--- a/monai/README.md
+++ b/monai/README.md
@@ -2,6 +2,8 @@
* **apps**: high level medical domain specific deep learning applications.
+* **auto3dseg**: automated machine learning (AutoML) components for volumetric image analysis.
+
* **bundle**: components to build the portable self-descriptive model bundle.
* **config**: for system configuration and diagnostic output.
@@ -12,6 +14,8 @@
* **engines**: engine-derived classes for extending Ignite behaviour.
+* **fl**: federated learning components to allow pipeline integration with any federated learning framework.
+
* **handlers**: defines handlers for implementing functionality at various stages in the training process.
* **inferers**: defines model inference methods.
diff --git a/monai/__init__.py b/monai/__init__.py
index e56a2f34440..3f6c06d82d1 100644
--- a/monai/__init__.py
+++ b/monai/__init__.py
@@ -39,7 +39,17 @@
# handlers_* have some external decorators the users may not have installed
# *.so files and folder "_C" may not exist when the cpp extensions are not compiled
-excludes = "(^(monai.handlers))|(^(monai.bundle))|((\\.so)$)|(^(monai._C))"
+excludes = "|".join(
+ [
+ "(^(monai.handlers))",
+ "(^(monai.bundle))",
+ "(^(monai.fl))",
+ "((\\.so)$)",
+ "(^(monai._C))",
+ "(.*(__main__)$)",
+ "(.*(video_dataset)$)",
+ ]
+)
# load directory modules only, skip loading individual files
load_submodules(sys.modules[__name__], False, exclude_pattern=excludes)
@@ -49,10 +59,12 @@
__all__ = [
"apps",
+ "auto3dseg",
"bundle",
"config",
"data",
"engines",
+ "fl",
"handlers",
"inferers",
"losses",
diff --git a/monai/_extensions/gmm/gmm.cpp b/monai/_extensions/gmm/gmm.cpp
index 4087095340a..577e5b117ef 100644
--- a/monai/_extensions/gmm/gmm.cpp
+++ b/monai/_extensions/gmm/gmm.cpp
@@ -58,12 +58,10 @@ torch::Tensor apply(torch::Tensor gmm_tensor, torch::Tensor input_tensor) {
unsigned int batch_count = input_tensor.size(0);
unsigned int element_count = input_tensor.stride(1);
- long int* output_size = new long int[dim];
- memcpy(output_size, input_tensor.sizes().data(), dim * sizeof(long int));
+ auto output_size = input_tensor.sizes().vec();
output_size[1] = MIXTURE_COUNT;
torch::Tensor output_tensor =
- torch::empty(c10::IntArrayRef(output_size, dim), torch::dtype(torch::kFloat32).device(device_type));
- delete output_size;
+ torch::empty(c10::IntArrayRef(output_size), torch::dtype(torch::kFloat32).device(device_type));
const float* gmm = gmm_tensor.data_ptr();
const float* input = input_tensor.data_ptr();
diff --git a/monai/_version.py b/monai/_version.py
index 79f569dd79c..256d3654d93 100644
--- a/monai/_version.py
+++ b/monai/_version.py
@@ -6,7 +6,7 @@
# that just contains the computed version number.
# This file is released into the public domain. Generated by
-# versioneer-0.19 (https://github.com/python-versioneer/python-versioneer)
+# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer)
"""Git implementation of _version.py."""
@@ -15,6 +15,8 @@
import re
import subprocess
import sys
+from typing import Callable, Dict
+import functools
def get_keywords():
@@ -52,8 +54,8 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
-LONG_VERSION_PY = {}
-HANDLERS = {}
+LONG_VERSION_PY: Dict[str, str] = {}
+HANDLERS: Dict[str, Dict[str, Callable]] = {}
def register_vcs_handler(vcs, method): # decorator
@@ -71,17 +73,25 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
- p = None
- for c in commands:
+ process = None
+
+ popen_kwargs = {}
+ if sys.platform == "win32":
+ # This hides the console window if pythonw.exe is used
+ startupinfo = subprocess.STARTUPINFO()
+ startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
+ popen_kwargs["startupinfo"] = startupinfo
+
+ for command in commands:
try:
- dispcmd = str([c] + args)
+ dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
- p = subprocess.Popen([c] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None))
+ process = subprocess.Popen([command] + args, cwd=cwd, env=env,
+ stdout=subprocess.PIPE,
+ stderr=(subprocess.PIPE if hide_stderr
+ else None), **popen_kwargs)
break
- except EnvironmentError:
+ except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
@@ -93,13 +103,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
if verbose:
print("unable to find command, tried %s" % (commands,))
return None, None
- stdout = p.communicate()[0].strip().decode()
- if p.returncode != 0:
+ stdout = process.communicate()[0].strip().decode()
+ if process.returncode != 0:
if verbose:
print("unable to run %s (error)" % dispcmd)
print("stdout was %s" % stdout)
- return None, p.returncode
- return stdout, p.returncode
+ return None, process.returncode
+ return stdout, process.returncode
def versions_from_parentdir(parentdir_prefix, root, verbose):
@@ -111,15 +121,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
"""
rootdirs = []
- for i in range(3):
+ for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None,
"dirty": False, "error": None, "date": None}
- else:
- rootdirs.append(root)
- root = os.path.dirname(root) # up a level
+ rootdirs.append(root)
+ root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %s but none started with prefix %s" %
@@ -136,22 +145,21 @@ def git_get_keywords(versionfile_abs):
# _version.py.
keywords = {}
try:
- f = open(versionfile_abs, "r")
- for line in f.readlines():
- if line.strip().startswith("git_refnames ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["refnames"] = mo.group(1)
- if line.strip().startswith("git_full ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["full"] = mo.group(1)
- if line.strip().startswith("git_date ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["date"] = mo.group(1)
- f.close()
- except EnvironmentError:
+ with open(versionfile_abs, "r") as fobj:
+ for line in fobj:
+ if line.strip().startswith("git_refnames ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["refnames"] = mo.group(1)
+ if line.strip().startswith("git_full ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["full"] = mo.group(1)
+ if line.strip().startswith("git_date ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["date"] = mo.group(1)
+ except OSError:
pass
return keywords
@@ -159,8 +167,8 @@ def git_get_keywords(versionfile_abs):
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
"""Get version information from git keywords."""
- if not keywords:
- raise NotThisMethod("no keywords at all, weird")
+ if "refnames" not in keywords:
+ raise NotThisMethod("Short version file found")
date = keywords.get("date")
if date is not None:
# Use only the last line. Previous lines may contain GPG signature
@@ -179,11 +187,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
- refs = set([r.strip() for r in refnames.strip("()").split(",")])
+ refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
- tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
+ tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
@@ -192,7 +200,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
- tags = set([r for r in refs if re.search(r'\d', r)])
+ tags = {r for r in refs if re.search(r'\d', r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
@@ -201,6 +209,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):]
+ # Filter out refs that exactly match prefix or that don't start
+ # with a number once the prefix is stripped (mostly a concern
+ # when prefix is '')
+ if not re.match(r'\d', r):
+ continue
if verbose:
print("picking %s" % r)
return {"version": r,
@@ -216,7 +229,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
@register_vcs_handler("git", "pieces_from_vcs")
-def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
+def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
"""Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not*
@@ -227,8 +240,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
- out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
+ # GIT_DIR can interfere with correct operation of Versioneer.
+ # It may be intended to be passed to the Versioneer-versioned project,
+ # but that should not change where we get our version from.
+ env = os.environ.copy()
+ env.pop("GIT_DIR", None)
+ runner = functools.partial(runner, env=env)
+
+ _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
+ hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
@@ -236,15 +256,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
- "--always", "--long",
- "--match", "%s*" % tag_prefix],
- cwd=root)
+ describe_out, rc = runner(GITS, [
+ "describe", "--tags", "--dirty", "--always", "--long",
+ "--match", f"{tag_prefix}[[:digit:]]*"
+ ], cwd=root)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip()
- full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
+ full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None:
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()
@@ -254,6 +274,39 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
+ branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
+ cwd=root)
+ # --abbrev-ref was added in git-1.6.3
+ if rc != 0 or branch_name is None:
+ raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
+ branch_name = branch_name.strip()
+
+ if branch_name == "HEAD":
+ # If we aren't exactly on a branch, pick a branch which represents
+ # the current commit. If all else fails, we are on a branchless
+ # commit.
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
+ # --contains was added in git-1.5.4
+ if rc != 0 or branches is None:
+ raise NotThisMethod("'git branch --contains' returned error")
+ branches = branches.split("\n")
+
+ # Remove the first line if we're running detached
+ if "(" in branches[0]:
+ branches.pop(0)
+
+ # Strip off the leading "* " from the list of branches.
+ branches = [branch[2:] for branch in branches]
+ if "master" in branches:
+ branch_name = "master"
+ elif not branches:
+ branch_name = None
+ else:
+ # Pick the first branch that is returned. Good or bad.
+ branch_name = branches[0]
+
+ pieces["branch"] = branch_name
+
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens.
git_describe = describe_out
@@ -270,7 +323,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
if not mo:
- # unparseable. Maybe git-describe is misbehaving?
+ # unparsable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%s'"
% describe_out)
return pieces
@@ -295,13 +348,11 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else:
# HEX: no tags
pieces["closest-tag"] = None
- count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
- cwd=root)
- pieces["distance"] = int(count_out) # total number of commits
+ out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
+ pieces["distance"] = len(out.split()) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
- date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
- cwd=root)[0].strip()
+ date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
@@ -342,16 +393,64 @@ def render_pep440(pieces):
return rendered
+def render_pep440_branch(pieces):
+ """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
+
+ The ".dev0" means not master branch. Note that .dev0 sorts backwards
+ (a feature branch will appear "older" than the master branch).
+
+ Exceptions:
+ 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0"
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += "+untagged.%d.g%s" % (pieces["distance"],
+ pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
+def pep440_split_post(ver):
+ """Split pep440 version string at the post-release segment.
+
+ Returns the release segments before the post-release and the
+ post-release version number (or -1 if no post-release segment is present).
+ """
+ vc = str.split(ver, ".post")
+ return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
+
+
def render_pep440_pre(pieces):
- """TAG[.post0.devDISTANCE] -- No -dirty.
+ """TAG[.postN.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
if pieces["distance"]:
- rendered += ".post0.dev%d" % pieces["distance"]
+ # update the post release segment
+ tag_version, post_version = pep440_split_post(pieces["closest-tag"])
+ rendered = tag_version
+ if post_version is not None:
+ rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
+ else:
+ rendered += ".post0.dev%d" % (pieces["distance"])
+ else:
+ # no commits, use the tag as the version
+ rendered = pieces["closest-tag"]
else:
# exception #1
rendered = "0.post0.dev%d" % pieces["distance"]
@@ -385,6 +484,35 @@ def render_pep440_post(pieces):
return rendered
+def render_pep440_post_branch(pieces):
+ """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
+
+ The ".dev0" means not master branch.
+
+ Exceptions:
+ 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += ".post%d" % pieces["distance"]
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "g%s" % pieces["short"]
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0.post%d" % pieces["distance"]
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += "+g%s" % pieces["short"]
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
def render_pep440_old(pieces):
"""TAG[.postDISTANCE[.dev0]] .
@@ -461,10 +589,14 @@ def render(pieces, style):
if style == "pep440":
rendered = render_pep440(pieces)
+ elif style == "pep440-branch":
+ rendered = render_pep440_branch(pieces)
elif style == "pep440-pre":
rendered = render_pep440_pre(pieces)
elif style == "pep440-post":
rendered = render_pep440_post(pieces)
+ elif style == "pep440-post-branch":
+ rendered = render_pep440_post_branch(pieces)
elif style == "pep440-old":
rendered = render_pep440_old(pieces)
elif style == "git-describe":
@@ -500,7 +632,7 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
- for i in cfg.versionfile_source.split('/'):
+ for _ in cfg.versionfile_source.split('/'):
root = os.path.dirname(root)
except NameError:
return {"version": "0+unknown", "full-revisionid": None,
diff --git a/monai/apps/auto3dseg/__init__.py b/monai/apps/auto3dseg/__init__.py
new file mode 100644
index 00000000000..7c335f48504
--- /dev/null
+++ b/monai/apps/auto3dseg/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .auto_runner import AutoRunner
+from .bundle_gen import BundleAlgo, BundleGen
+from .data_analyzer import DataAnalyzer
+from .ensemble_builder import AlgoEnsemble, AlgoEnsembleBestByFold, AlgoEnsembleBestN, AlgoEnsembleBuilder
+from .hpo_gen import NNIGen, OptunaGen
+from .utils import export_bundle_algo_history, import_bundle_algo_history
diff --git a/monai/apps/auto3dseg/__main__.py b/monai/apps/auto3dseg/__main__.py
new file mode 100644
index 00000000000..eec56b75824
--- /dev/null
+++ b/monai/apps/auto3dseg/__main__.py
@@ -0,0 +1,32 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from monai.apps.auto3dseg.auto_runner import AutoRunner
+from monai.apps.auto3dseg.bundle_gen import BundleAlgo, BundleGen
+from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
+from monai.apps.auto3dseg.ensemble_builder import AlgoEnsembleBuilder
+from monai.apps.auto3dseg.hpo_gen import NNIGen, OptunaGen
+
+if __name__ == "__main__":
+ from monai.utils import optional_import
+
+ fire, _ = optional_import("fire")
+ fire.Fire(
+ {
+ "DataAnalyzer": DataAnalyzer,
+ "BundleGen": BundleGen,
+ "BundleAlgo": BundleAlgo,
+ "AlgoEnsembleBuilder": AlgoEnsembleBuilder,
+ "AutoRunner": AutoRunner,
+ "NNIGen": NNIGen,
+ "OptunaGen": OptunaGen,
+ }
+ )
diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py
new file mode 100644
index 00000000000..138e751e993
--- /dev/null
+++ b/monai/apps/auto3dseg/auto_runner.py
@@ -0,0 +1,609 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import subprocess
+from copy import deepcopy
+from time import sleep
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+
+from monai.apps.auto3dseg.bundle_gen import BundleGen
+from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
+from monai.apps.auto3dseg.ensemble_builder import (
+ AlgoEnsemble,
+ AlgoEnsembleBestByFold,
+ AlgoEnsembleBestN,
+ AlgoEnsembleBuilder,
+)
+from monai.apps.auto3dseg.hpo_gen import NNIGen
+from monai.apps.auto3dseg.utils import export_bundle_algo_history, import_bundle_algo_history
+from monai.apps.utils import get_logger
+from monai.auto3dseg.utils import algo_to_pickle
+from monai.bundle import ConfigParser
+from monai.transforms import SaveImage
+from monai.utils.enums import AlgoEnsembleKeys
+from monai.utils.module import look_up_option, optional_import
+
+logger = get_logger(module_name=__name__)
+
+nni, has_nni = optional_import("nni")
+
+
+class AutoRunner:
+ """
+ An interface for handling Auto3Dseg with minimal inputs and understanding of the internal states in Auto3Dseg.
+ The users can run the Auto3Dseg with default settings in one line of code. They can also customize the advanced
+ features Auto3Dseg in a few additional lines. Examples of customization include
+
+ - change cross-validation folds
+ - change training/prediction parameters
+ - change ensemble methods
+ - automatic hyperparameter optimization.
+
+ The output of the interface is a directory that contains
+
+ - data statistics analysis report
+ - algorithm definition files (scripts, configs, pickle objects) and training results (checkpoints, accuracies)
+ - the predictions on the testing datasets from the final algorithm ensemble
+ - a copy of the input arguments in form of YAML
+ - cached intermediate results
+
+ Args:
+ work_dir: working directory to save the intermediate and final results.
+ input: the configuration dictionary or the file path to the configuration in form of YAML.
+ The configuration should contain datalist, dataroot, modality, multigpu, and class_names info.
+ analyze: on/off switch to run DataAnalyzer and generate a datastats report. If it is set to False,
+ The AutoRunner will attempt to skip datastats analysis and use cached results. If there is no such cache,
+ AutoRunner will report an error and stop.
+ algo_gen: on/off switch to run AlgoGen and generate templated BundleAlgos. If it is set to False,
+ The AutoRunner will attempt to skip the algorithm generation and stop if there is no cache to load.
+ train: on/off switch to run training and generate algorithm checkpoints. If it is set to False,
+ The AutoRunner will attempt to skip the training for all algorithms. If there is zero trained algorithm
+ but ``train`` is set to False, AutoRunner will stop.
+ hpo: use hyperparameter optimization (HPO) in the training phase. Users can provide a list of
+ hyper-parameter and a search will be performed to investigate the algorithm performances.
+ hpo_backend: a string that indicates the backend of the HPO. Currently, only NNI Grid-search mode
+ is supported
+ ensemble: on/off switch to run model ensemble and use the ensemble to predict outputs in testing
+ datasets.
+ not_use_cache: if the value is True, it will ignore all cached results in data analysis,
+ algorithm generation, or training, and start the pipeline from scratch.
+ kwargs: image writing parameters for the ensemble inference. The kwargs format follows the SaveImage
+ transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage.
+
+
+ Examples:
+ - User can use the one-liner to start the Auto3Dseg workflow
+
+ .. code-block:: bash
+
+ python -m monai.apps.auto3dseg AutoRunner run --input \
+ '{"modality": "ct", "datalist": "dl.json", "dataroot": "/dr", "multigpu": true, "class_names": ["A", "B"]}'
+
+ - User can also save the input dictionary as a input YAML file and use the following one-liner
+
+ .. code-block:: bash
+
+ python -m monai.apps.auto3dseg AutoRunner run --input=./input.yaml
+
+ - User can specify work_dir and data source config input and run AutoRunner:
+
+ .. code-block:: python
+
+ work_dir = "./work_dir"
+ input = "path_to_yaml_data_cfg"
+ runner = AutoRunner(work_dir=work_dir, input=input)
+ runner.run()
+
+ - User can specify training parameters by:
+
+ .. code-block:: python
+
+ input = "path_to_yaml_data_cfg"
+ runner = AutoRunner(input=input)
+ train_param = {
+ "CUDA_VISIBLE_DEVICES": [0],
+ "num_iterations": 8,
+ "num_iterations_per_validation": 4,
+ "num_images_per_batch": 2,
+ "num_epochs": 2,
+ }
+ runner.set_training_params(params=train_param) # 2 epochs
+ runner.run()
+
+ - User can specify the fold number of cross validation
+
+ .. code-block:: python
+
+ input = "path_to_yaml_data_cfg"
+ runner = AutoRunner(input=input)
+ runner.set_num_fold(n_fold = 2)
+ runner.run()
+
+ - User can specify the prediction parameters during algo ensemble inference:
+
+ .. code-block:: python
+
+ input = "path_to_yaml_data_cfg"
+ pred_params = {
+ 'files_slices': slice(0,2),
+ 'mode': "vote",
+ 'sigmoid': True,
+ }
+ runner = AutoRunner(input=input)
+ runner.set_prediction_params(params=pred_params)
+ runner.run()
+
+ - User can define a grid search space and use the HPO during training.
+
+ .. code-block:: python
+
+ input = "path_to_yaml_data_cfg"
+ pred_param = {
+ "CUDA_VISIBLE_DEVICES": [0],
+ "num_iterations": 8,
+ "num_iterations_per_validation": 4,
+ "num_images_per_batch": 2,
+ "num_epochs": 2,
+ }
+ runner = AutoRunner(input=input, hpo=True)
+ runner.set_nni_search_space({"learning_rate": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]}})
+ runner.run()
+
+ Notes:
+ Expected results in the work_dir as below::
+
+ work_dir/
+ ├── algorithm_templates # bundle algo templates (scripts/configs)
+ ├── cache.yaml # Autorunner will automatically cache results to save time
+ ├── datastats.yaml # datastats of the dataset
+ ├── dints_0 # network scripts/configs/checkpoints and pickle object of the algo
+ ├── ensemble_output # the prediction of testing datasets from the ensemble of the algos
+ ├── input.yaml # copy of the input data source configs
+ ├── segresnet_0 # network scripts/configs/checkpoints and pickle object of the algo
+ ├── segresnet2d_0 # network scripts/configs/checkpoints and pickle object of the algo
+ └── swinunetr_0 # network scripts/configs/checkpoints and pickle object of the algo
+
+ """
+
+ def __init__(
+ self,
+ work_dir: str = "./work_dir",
+ input: Union[Dict[str, Any], str, None] = None,
+ analyze: bool = True,
+ algo_gen: bool = True,
+ train: bool = True,
+ hpo: bool = False,
+ hpo_backend: str = "nni",
+ ensemble: bool = True,
+ not_use_cache: bool = False,
+ **kwargs,
+ ):
+ if not os.path.isdir(work_dir):
+ logger.info(f"{work_dir} does not exists. Creating...")
+ os.makedirs(work_dir)
+ logger.info(f"{work_dir} created to save all results")
+ else:
+ logger.info(f"Work directory {work_dir} is used to save all results")
+
+ self.work_dir = os.path.abspath(work_dir)
+ self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")
+
+ if input is None:
+ input = os.path.join(self.work_dir, "input.yaml")
+ elif isinstance(input, Dict):
+ ConfigParser.export_config_file(input, self.data_src_cfg_name)
+ self.data_src_cfg = input
+ elif isinstance(input, str) and os.path.isfile(input):
+ shutil.copy(input, self.data_src_cfg_name)
+ logger.info(f"Loading {input} for AutoRunner and making a copy in {self.data_src_cfg_name}")
+ self.data_src_cfg = ConfigParser.load_config_file(self.data_src_cfg_name)
+ else:
+ raise ValueError(f"{input} is not a valid file")
+
+ self.not_use_cache = not_use_cache
+ self.cache_filename = os.path.join(self.work_dir, "cache.yaml")
+ self.cache = self.check_cache()
+ self.export_cache()
+
+ # Whether we need all the steps or not
+ self.analyze = self.check_analyze(analyze)
+ self.algo_gen = self.check_algo_gen(algo_gen)
+ self.train = self.check_train(train)
+ self.ensemble = ensemble # last step, no need to check
+
+ # intermediate variables
+ self.dataroot = self.data_src_cfg["dataroot"]
+ self.datalist_filename = self.data_src_cfg["datalist"]
+ self.datastats_filename = os.path.join(self.work_dir, "datastats.yaml")
+ self.set_training_params()
+ self.set_num_fold()
+ self.set_prediction_params()
+ self.save_image = self.set_image_save_transform(kwargs)
+ self.ensemble_method: AlgoEnsemble
+ self.set_ensemble_method()
+
+ # hpo
+ if hpo_backend.lower() != "nni":
+ raise NotImplementedError("HPOGen backend only supports NNI")
+ self.hpo = hpo and has_nni
+ self.set_hpo_params()
+ self.search_space: Dict[str, Dict[str, Any]] = {}
+ self.hpo_tasks = 0
+
+ def check_cache(self):
+ """
+ Check if the intermediate result is cached after each step in the current working directory
+
+ Returns:
+ a dict of cache results. If not_use_cache is set to True, or there is no cache file in the
+ working directory, the result will be ``empty_cache`` in which all ``has_cache`` keys are
+ set to False.
+ """
+ empty_cache = {
+ "analyze": {"has_cache": False, "datastats": None},
+ "algo_gen": {"has_cache": False},
+ "train": {"has_cache": False},
+ }
+ if self.not_use_cache or not os.path.isfile(self.cache_filename):
+ return empty_cache
+
+ cache = ConfigParser.load_config_file(self.cache_filename)
+
+ if cache["analyze"]["has_cache"]:
+ # check if the file in the right format and exists.
+ if not isinstance(cache["analyze"]["datastats"], str):
+ cache["analyze"] = False
+ cache["analyze"]["datastats"] = None
+
+ if not os.path.isfile(cache["analyze"]["datastats"]):
+ cache["analyze"]["has_cache"] = False
+
+ if cache["algo_gen"]["has_cache"]:
+ history = import_bundle_algo_history(self.work_dir, only_trained=False)
+ if len(history) == 0: # no saved algo_objects
+ cache["algo_gen"]["has_cache"] = False
+
+ if cache["train"]["has_cache"]:
+ trained_history = import_bundle_algo_history(self.work_dir, only_trained=True)
+ if len(trained_history) == 0:
+ cache["train"]["has_cache"] = False
+
+ return cache
+
+ def export_cache(self):
+ """
+ Save the cache state as ``cache.yaml`` in the working directory
+ """
+ ConfigParser.export_config_file(self.cache, self.cache_filename, fmt="yaml", default_flow_style=None)
+
+ def check_analyze(self, analyze: bool):
+ """Check if the AutoRunner can skip data analysis."""
+
+ if self.cache["analyze"]["has_cache"]:
+ return False # we can use cached result
+
+ if analyze:
+ return True # we need to do analysis
+ else:
+ raise ValueError(
+ f"cache data analysis report is not found in {self.work_dir}"
+ "or the cache.yaml file is missing in the directory"
+ )
+
+ def check_algo_gen(self, algo_gen: bool):
+ """Check if the AutoRunner can skip AlgoGen/BundleGen."""
+
+ if self.cache["algo_gen"]["has_cache"]:
+ return False # we can use cached result
+
+ if algo_gen:
+ return True # we need to do algo_gen
+ else:
+ raise ValueError(
+ f"algo_object.pkl is not found in the task folders under {self.work_dir}"
+ "or the cache.yaml file is missing in the directory"
+ )
+
+ def check_train(self, train: bool):
+ """Check if the AutoRunner can skip training."""
+
+ if self.cache["train"]["has_cache"]:
+ return False # we can use cached result
+
+ if train:
+ return True # we need to do training
+ else:
+ raise ValueError(
+ f"algo_object.pkl in the task folders under {self.work_dir} has no [best_metrics] key"
+ "or the cache.yaml file is missing in the directory"
+ )
+
+ def set_num_fold(self, num_fold: int = 5):
+ """
+ Set the number of cross validation folds for all algos.
+
+ Args:
+ num_fold: a positive integer to define the number of folds.
+ """
+ if num_fold <= 0:
+ raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")
+ self.num_fold = num_fold
+
+ def set_training_params(self, params: Optional[Dict[str, Any]] = None):
+ """
+ Set the training params for all algos.
+
+ Args:
+ params: a dict that defines the overriding key-value pairs during training. The overriding method
+ is defined by the algo class.
+
+ Examples:
+ For BundleAlgo objects, the training parameter to shorten the training time to a few epochs can be
+ {"num_iterations": 8, "num_iterations_per_validation": 4}
+
+ """
+ if params is None:
+ self.train_params = {}
+ else:
+ self.train_params = deepcopy(params)
+
+ def set_prediction_params(self, params: Optional[Dict[str, Any]] = None):
+ """
+ Set the prediction params for all algos.
+
+ Args:
+ params: a dict that defines the overriding key-value pairs during prediction. The overriding method
+ is defined by the algo class.
+
+ Examples:
+
+ For BundleAlgo objects, this set of param will specify the algo ensemble to only inference the first
+ two files in the testing datalist {"file_slices": slice(0, 2)}
+
+ """
+ if params is None:
+ self.pred_params = {"sigmoid": True} # output will be 0-1
+ else:
+ self.pred_params = deepcopy(params)
+
+ def set_hpo_params(self, params: Optional[Dict[str, Any]] = None):
+ """
+ Set parameters for the HPO module and the algos before the training. It will attempt to (1) override bundle
+ templates with the key-value pairs in ``params`` (2) change the config of the HPO module (e.g. NNI) if the
+ key is found to be one of:
+
+ - "trialCodeDirectory"
+ - "trialGpuNumber"
+ - "trialConcurrency"
+ - "maxTrialNumber"
+ - "maxExperimentDuration"
+ - "tuner"
+ - "trainingService"
+
+ Args:
+ params: a dict that defines the overriding key-value pairs during instantiation of the algo. For
+ BundleAlgo, it will override the template config filling.
+ """
+ if params is None:
+ self.hpo_params = self.train_params
+ else:
+ self.hpo_params = params
+
+ def set_nni_search_space(self, search_space):
+ """
+ Set the search space for NNI parameter search.
+
+ Args:
+ search_space: hyper parameter search space in the form of dict. For more information, please check
+ NNI documentation: https://nni.readthedocs.io/en/v2.2/Tutorial/SearchSpaceSpec.html .
+ """
+ value_combinations = 1
+ for k, v in search_space.items():
+ if "_value" not in v:
+ raise ValueError(f"{search_space} key {k} value {v} has not _value")
+ value_combinations *= len(v["_value"])
+
+ self.search_space = search_space
+ self.hpo_tasks = value_combinations
+
+ def set_image_save_transform(self, kwargs):
+ """
+ Set the ensemble output transform.
+
+ Args:
+ kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
+ transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage .
+
+ """
+
+ if "output_dir" in kwargs:
+ output_dir = kwargs.pop("output_dir")
+ else:
+ output_dir = os.path.join(self.work_dir, "ensemble_output")
+ logger.info(f"The output_dir is not specified. {output_dir} will be used to save ensemble predictions")
+
+ if not os.path.isdir(output_dir):
+ os.makedirs(output_dir)
+ logger.info(f"Directory {output_dir} is created to save ensemble predictions")
+
+ if "output_postfix" in kwargs:
+ output_postfix = kwargs.pop("output_postfix")
+ else:
+ output_postfix = "ensemble"
+
+ self.output_dir = output_dir
+ return SaveImage(output_dir=output_dir, output_postfix=output_postfix, **kwargs)
+
+ def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestN", **kwargs):
+ """
+ Set the bundle ensemble method
+
+ Args:
+ ensemble_method_name: the name of the ensemble method. Only two methods are supported "AlgoEnsembleBestN"
+ and "AlgoEnsembleBestByFold".
+ kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for
+ ``AlgoEnsembleBestN`` is supported.
+
+ """
+ self.ensemble_method_name = look_up_option(
+ ensemble_method_name, supported=["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"]
+ )
+ if self.ensemble_method_name == "AlgoEnsembleBestN":
+ n_best = kwargs.pop("n_best", False)
+ n_best = 2 if not n_best else n_best
+ self.ensemble_method = AlgoEnsembleBestN(n_best=n_best)
+ elif self.ensemble_method_name == "AlgoEnsembleBestByFold":
+ self.ensemble_method = AlgoEnsembleBestByFold(n_fold=self.num_fold)
+ else:
+ raise NotImplementedError(f"Ensemble method {self.ensemble_method_name} is not implemented.")
+
+ def _train_algo_in_sequence(self, history: List[Dict[str, Any]]):
+ """
+ Train the Algos in a sequential scheme. The order of training is randomized.
+
+ Args:
+ history: the history of generated Algos. It is a list of dicts. Each element has the task name
+ (e.g. "dints_0" for dints network in fold 0) as the key and the algo object as the value.
+ After the training, the algo object with the ``best_metric`` will be saved as a pickle file.
+
+ Note:
+ The final results of the model training will be written to all the generated algorithm's output
+ folders under the working directory. The results include the model checkpoints, a
+ progress.yaml, accuracies in CSV and a pickle file of the Algo object.
+ """
+ for task in history:
+ for _, algo in task.items():
+ algo.train(self.train_params)
+ acc = algo.get_score()
+ algo_to_pickle(algo, template_path=algo.template_path, best_metrics=acc)
+
+ def _train_algo_in_nni(self, history):
+ """
+ Train the Algos using HPO.
+
+ Args:
+ history: the history of generated Algos. It is a list of dicts. Each element has the task name
+ (e.g. "dints_0" for dints network in fold 0) as the key and the algo object as the value.
+ After the training, the algo object with the ``best_metric`` will be saved as a pickle file.
+
+ Note:
+ The final results of the model training will not be written to all the previously generated
+ algorithm's output folders. Instead, HPO will generate a new algo during the searching, and
+ the new algo will be saved under the working directory with a different format of the name.
+ For example, if the searching space has "learning_rate", the result of HPO will be written to
+ a folder name with original task name and the param (e.g. "dints_0_learning_rate_0.001").
+ The results include the model checkpoints, a progress.yaml, accuracies in CSV and a pickle
+ file of the Algo object.
+
+ """
+ default_nni_config = {
+ "trialCodeDirectory": ".",
+ "trialGpuNumber": torch.cuda.device_count(),
+ "trialConcurrency": 1,
+ "maxTrialNumber": 10,
+ "maxExperimentDuration": "1h",
+ "tuner": {"name": "GridSearch"},
+ "trainingService": {"platform": "local", "useActiveGpu": True},
+ }
+
+ last_total_tasks = len(import_bundle_algo_history(self.work_dir, only_trained=True))
+ for task in history:
+ for name, algo in task.items():
+ nni_gen = NNIGen(algo=algo, params=self.hpo_params)
+ obj_filename = nni_gen.get_obj_filename()
+ nni_config = deepcopy(default_nni_config)
+ # override the default nni config with the same key in hpo_params
+ for key in self.hpo_params:
+ if key in nni_config:
+ nni_config[key] = self.hpo_params[key]
+ nni_config.update({"experimentName": name})
+ nni_config.update({"search_space": self.search_space})
+ trial_cmd = "python -m monai.apps.auto3dseg NNIGen run_algo " + obj_filename + " " + self.work_dir
+ nni_config.update({"trialCommand": trial_cmd})
+ nni_config_filename = os.path.abspath(os.path.join(self.work_dir, "nni_config.yaml"))
+ ConfigParser.export_config_file(nni_config, nni_config_filename, fmt="yaml", default_flow_style=None)
+
+ max_trial = min(self.hpo_tasks, default_nni_config["maxTrialNumber"])
+ cmd = "nnictl create --config " + nni_config_filename + " --port 8088"
+ subprocess.run(cmd.split(), check=True)
+
+ n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))
+ while n_trainings - last_total_tasks < max_trial:
+ sleep(1)
+ n_trainings = len(import_bundle_algo_history(self.work_dir, only_trained=True))
+
+ cmd = "nnictl stop --all"
+ subprocess.run(cmd.split(), check=True)
+ logger.info(f"NNI completes HPO on {name}")
+ last_total_tasks = n_trainings
+
+ def run(self):
+ """
+ Run the AutoRunner pipeline
+ """
+ # step 1: data analysis
+ if self.analyze:
+ da = DataAnalyzer(self.datalist_filename, self.dataroot, output_path=self.datastats_filename)
+ da.get_all_case_stats()
+ self.cache["analyze"]["has_cache"] = True
+ self.cache["analyze"]["datastats"] = self.datastats_filename
+ self.export_cache()
+ else:
+ logger.info("Found cached results and skipping data analysis...")
+
+ # step 2: algorithm generation
+ if self.algo_gen:
+ bundle_generator = BundleGen(
+ algo_path=self.work_dir,
+ data_stats_filename=self.datastats_filename,
+ data_src_cfg_name=self.data_src_cfg_name,
+ )
+
+ bundle_generator.generate(self.work_dir, num_fold=self.num_fold)
+ history = bundle_generator.get_history()
+ export_bundle_algo_history(history)
+ self.cache["algo_gen"]["has_cache"] = True
+ self.export_cache()
+ else:
+ logger.info("Found cached results and skipping algorithm generation...")
+
+ # step 3: algo training
+ if self.train:
+ history = import_bundle_algo_history(self.work_dir, only_trained=False)
+ if not self.hpo:
+ self._train_algo_in_sequence(history)
+ else:
+ self._train_algo_in_nni(history)
+ self.cache["train"]["has_cache"] = True
+ self.export_cache()
+ else:
+ logger.info("Found cached results and skipping algorithm training...")
+
+ # step 4: model ensemble and write the prediction to disks.
+ if self.ensemble:
+ history = import_bundle_algo_history(self.work_dir, only_trained=True)
+ builder = AlgoEnsembleBuilder(history, self.data_src_cfg_name)
+ builder.set_ensemble_method(self.ensemble_method)
+ ensembler = builder.get_ensemble()
+ preds = ensembler(pred_param=self.pred_params) # apply sigmoid to binarize the prediction
+ print("Auto3Dseg picked the following networks to ensemble:")
+ for algo in ensembler.get_algo_ensemble():
+ print(algo[AlgoEnsembleKeys.ID])
+
+ for pred in preds:
+ self.save_image(pred)
+ logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.")
+
+ logger.info("Auto3Dseg pipeline is complete successfully.")
diff --git a/monai/apps/auto3dseg/bundle_gen.py b/monai/apps/auto3dseg/bundle_gen.py
new file mode 100644
index 00000000000..6c04b2e7e2f
--- /dev/null
+++ b/monai/apps/auto3dseg/bundle_gen.py
@@ -0,0 +1,405 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import os
+import shutil
+import subprocess
+import sys
+from copy import deepcopy
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Any, Dict, List, Mapping
+
+import torch
+
+from monai.apps import download_and_extract
+from monai.apps.utils import get_logger
+from monai.auto3dseg.algo_gen import Algo, AlgoGen
+from monai.auto3dseg.utils import algo_to_pickle
+from monai.bundle.config_parser import ConfigParser
+from monai.utils import ensure_tuple
+
+logger = get_logger(module_name=__name__)
+ALGO_HASH = os.environ.get("MONAI_ALGO_HASH", "d7bf36c")
+
+__all__ = ["BundleAlgo", "BundleGen"]
+
+
+class BundleAlgo(Algo):
+ """
+ An algorithm represented by a set of bundle configurations and scripts.
+
+ ``BundleAlgo.cfg`` is a ``monai.bundle.ConfigParser`` instance.
+
+ .. code-block:: python
+
+ from monai.apps.auto3dseg import BundleAlgo
+
+ data_stats_yaml = "/workspace/data_stats.yaml"
+ algo = BundleAlgo(template_path=../algorithms/templates/segresnet2d/configs)
+ algo.set_data_stats(data_stats_yaml)
+ # algo.set_data_src("../data_src.json")
+ algo.export_to_disk(".", algo_name="segresnet2d_1")
+
+ This class creates MONAI bundles from a directory of 'bundle template'. Different from the regular MONAI bundle
+ format, the bundle template may contain placeholders that must be filled using ``fill_template_config`` during
+ ``export_to_disk``. Then created bundle keeps the same file structure as the template.
+
+ """
+
+ def __init__(self, template_path: str):
+ """
+ Create an Algo instance based on the predefined Algo template.
+
+ Args:
+ template_path: path to the root of the algo template.
+
+ """
+
+ self.template_path = template_path
+ self.data_stats_files = ""
+ self.data_list_file = ""
+ self.output_path = ""
+ self.name = ""
+ self.best_metric = None
+ # track records when filling template config: {"": {"": value, ...}, ...}
+ self.fill_records: dict = {}
+
+ def set_data_stats(self, data_stats_files: str):
+ """
+ Set the data analysis report (generated by DataAnalyzer).
+
+ Args:
+ data_stats_files: path to the datastats yaml file
+ """
+ self.data_stats_files = data_stats_files
+
+ def set_data_source(self, data_src_cfg: str):
+ """
+ Set the data source configuration file
+
+ Args:
+ data_src_cfg: path to a configuration file (yaml) that contains datalist, dataroot, and other params.
+ The config will be in a form of {"modality": "ct", "datalist": "path_to_json_datalist", "dataroot":
+ "path_dir_data"}
+ """
+ self.data_list_file = data_src_cfg
+
+ def fill_template_config(self, data_stats_filename: str, algo_path: str, **kwargs) -> dict:
+ """
+ The configuration files defined when constructing this Algo instance might not have a complete training
+ and validation pipelines. Some configuration components and hyperparameters of the pipelines depend on the
+ training data and other factors. This API is provided to allow the creation of fully functioning config files.
+ Return the records of filling template config: {"": {"": value, ...}, ...}.
+
+ Args:
+ data_stats_filename: filename of the data stats report (generated by DataAnalyzer)
+
+ Notes:
+ Template filling is optional. The user can construct a set of pre-filled configs without replacing values
+ by using the data analysis results. It is also intended to be re-implemented in subclasses of BundleAlgo
+ if the user wants their own way of auto-configured template filling.
+ """
+ return {}
+
+ def export_to_disk(self, output_path: str, algo_name: str, **kwargs):
+ """
+ Fill the configuration templates, write the bundle (configs + scripts) to folder `output_path/algo_name`.
+
+ Args:
+ output_path: Path to export the 'scripts' and 'configs' directories.
+ algo_name: the identifier of the algorithm (usually contains the name and extra info like fold ID).
+ kwargs: other parameters, including: "copy_dirs=True/False" means whether to copy the template as output
+ instead of inplace operation, "fill_template=True/False" means whether to fill the placeholders
+ in the template. other parameters are for `fill_template_config` function.
+
+ """
+ if kwargs.pop("copy_dirs", True):
+ self.output_path = os.path.join(output_path, algo_name)
+ os.makedirs(self.output_path, exist_ok=True)
+ if os.path.isdir(self.output_path):
+ shutil.rmtree(self.output_path)
+ shutil.copytree(self.template_path, self.output_path)
+ else:
+ self.output_path = self.template_path
+ if kwargs.pop("fill_template", True):
+ self.fill_records = self.fill_template_config(self.data_stats_files, self.output_path, **kwargs)
+ logger.info(self.output_path)
+
+ def _create_cmd(self, train_params=None):
+ """
+ Create the command to execute training.
+
+ """
+ if train_params is not None:
+ params = deepcopy(train_params)
+
+ train_py = os.path.join(self.output_path, "scripts", "train.py")
+ config_dir = os.path.join(self.output_path, "configs")
+
+ if os.path.isdir(config_dir):
+ base_cmd = ""
+ for file in os.listdir(config_dir):
+ if len(base_cmd) == 0:
+ base_cmd += f"{train_py} run --config_file="
+ else:
+ base_cmd += "," # Python Fire does not accept space
+ # Python Fire may be confused by single-quoted WindowsPath
+ config_yaml = Path(os.path.join(config_dir, file)).as_posix()
+ base_cmd += f"'{config_yaml}'"
+
+ if "CUDA_VISIBLE_DEVICES" in params:
+ devices = params.pop("CUDA_VISIBLE_DEVICES")
+ n_devices, devices_info = len(devices), ",".join([str(x) for x in devices])
+ else:
+ n_devices, devices_info = torch.cuda.device_count(), ""
+ if n_devices > 1:
+ cmd = f"torchrun --nnodes={1:d} --nproc_per_node={n_devices:d} "
+ else:
+ cmd = "python " # TODO: which system python?
+ cmd += base_cmd
+ if params and isinstance(params, Mapping):
+ for k, v in params.items():
+ cmd += f" --{k}={v}"
+ return cmd, devices_info
+
+ def _run_cmd(self, cmd: str, devices_info: str):
+ """
+ Execute the training command with target devices information.
+
+ """
+ try:
+ logger.info(f"Launching: {cmd}")
+ ps_environ = os.environ.copy()
+ if devices_info:
+ ps_environ["CUDA_VISIBLE_DEVICES"] = devices_info
+ normal_out = subprocess.run(cmd.split(), env=ps_environ, check=True, capture_output=True)
+ logger.info(repr(normal_out).replace("\\n", "\n").replace("\\t", "\t"))
+ except subprocess.CalledProcessError as e:
+ output = repr(e.stdout).replace("\\n", "\n").replace("\\t", "\t")
+ errors = repr(e.stderr).replace("\\n", "\n").replace("\\t", "\t")
+ raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e
+ return normal_out
+
+ def train(self, train_params=None):
+ """
+ Load the run function in the training script of each model. Training parameter is predefined by the
+ algo_config.yaml file, which is pre-filled by the fill_template_config function in the same instance.
+
+ Args:
+ train_params: to specify the devices using a list of integers: ``{"CUDA_VISIBLE_DEVICES": [1,2,3]}``.
+ """
+ cmd, devices_info = self._create_cmd(train_params)
+ return self._run_cmd(cmd, devices_info)
+
+ def get_score(self, *args, **kwargs):
+ """
+ Returns validation scores of the model trained by the current Algo.
+ """
+ config_yaml = os.path.join(self.output_path, "configs", "hyper_parameters.yaml")
+ parser = ConfigParser()
+ parser.read_config(config_yaml)
+ ckpt_path = parser.get_parsed_content("ckpt_path", default=self.output_path)
+
+ dict_file = ConfigParser.load_config_file(os.path.join(ckpt_path, "progress.yaml"))
+ # dict_file: a list of scores saved in the form of dict in progress.yaml
+ return dict_file[-1]["best_avg_dice_score"] # the last one is the best one
+
+ def get_inferer(self, *args, **kwargs):
+ """
+ Load the InferClass from the infer.py. The InferClass should be defined in the template under the path of
+ `"scripts/infer.py"`. It is required to define the "InferClass" (name is fixed) with two functions at least
+ (``__init__`` and ``infer``). The init class has an override kwargs that can be used to override parameters in
+ the run-time optionally.
+
+ Examples:
+
+ .. code-block:: python
+
+ class InferClass
+ def __init__(self, config_file: Optional[Union[str, Sequence[str]]] = None, **override):
+ # read configs from config_file (sequence)
+ # set up transforms
+ # set up model
+ # set up other hyper parameters
+ return
+
+ @torch.no_grad()
+ def infer(self, image_file):
+ # infer the model and save the results to output
+ return output
+
+ """
+ infer_py = os.path.join(self.output_path, "scripts", "infer.py")
+ if not os.path.isfile(infer_py):
+ raise ValueError(f"{infer_py} is not found, please check the path.")
+
+ config_dir = os.path.join(self.output_path, "configs")
+ configs_path = [os.path.join(config_dir, f) for f in os.listdir(config_dir)]
+
+ spec = importlib.util.spec_from_file_location("InferClass", infer_py)
+ infer_class = importlib.util.module_from_spec(spec)
+ sys.modules["InferClass"] = infer_class
+ spec.loader.exec_module(infer_class)
+ return infer_class.InferClass(configs_path, *args, **kwargs)
+
+ def predict(self, predict_params=None):
+ """
+ Use the trained model to predict the outputs with a given input image. Path to input image is in the params
+ dict in a form of {"files", ["path_to_image_1", "path_to_image_2"]}. If it is not specified, then the
+ prediction will use the test images predefined in the bundle config.
+
+ Args:
+ predict_params: a dict to override the parameters in the bundle config (including the files to predict).
+
+ """
+ if predict_params is None:
+ params = {}
+ else:
+ params = deepcopy(predict_params)
+
+ files = params.pop("files", ".")
+ inferer = self.get_inferer(**params)
+ return [inferer.infer(f) for f in ensure_tuple(files)]
+
+ def get_output_path(self):
+ """Returns the algo output paths to find the algo scripts and configs."""
+ return self.output_path
+
+
+# path to download the algo_templates
+default_algo_zip = (
+ f"https://github.com/Project-MONAI/research-contributions/releases/download/algo_templates/{ALGO_HASH}.tar.gz"
+)
+
+# default algorithms
+default_algos = {
+ "segresnet2d": dict(_target_="segresnet2d.scripts.algo.Segresnet2dAlgo", template_path="segresnet2d"),
+ "dints": dict(_target_="dints.scripts.algo.DintsAlgo", template_path="dints"),
+ "swinunetr": dict(_target_="swinunetr.scripts.algo.SwinunetrAlgo", template_path="swinunetr"),
+ "segresnet": dict(_target_="segresnet.scripts.algo.SegresnetAlgo", template_path="segresnet"),
+}
+
+
+class BundleGen(AlgoGen):
+ """
+ This class generates a set of bundles according to the cross-validation folds, each of them can run independently.
+
+ Args:
+ algo_path: the directory path to save the algorithm templates. Default is the current working dir.
+ algos: if dictionary, it outlines the algorithm to use. if None, automatically download the zip file
+ from the default link. if string, it represents the download link.
+ The current default options are released at:
+ https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg
+ data_stats_filename: the path to the data stats file (generated by DataAnalyzer)
+ data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of
+ {"modality": "ct", "datalist": "path_to_json_datalist", "dataroot": "path_dir_data"}
+
+ .. code-block:: bash
+
+ python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/data_stats.yaml"
+ """
+
+ def __init__(self, algo_path: str = ".", algos=None, data_stats_filename=None, data_src_cfg_name=None):
+ self.algos: Any = []
+
+ if algos is None or isinstance(algos, str):
+ # trigger the download process
+ zip_download_dir = TemporaryDirectory()
+ algo_compressed_file = os.path.join(zip_download_dir.name, "algo_templates.tar.gz")
+ download_and_extract(default_algo_zip if algos is None else algos, algo_compressed_file, algo_path)
+ zip_download_dir.cleanup()
+ sys.path.insert(0, os.path.join(algo_path, "algorithm_templates"))
+ algos = deepcopy(default_algos)
+ for name in algos:
+ algos[name]["template_path"] = os.path.join(
+ algo_path, "algorithm_templates", algos[name]["template_path"]
+ )
+
+ if isinstance(algos, dict):
+ for algo_name, algo_params in algos.items():
+ try:
+ self.algos.append(ConfigParser(algo_params).get_parsed_content())
+ except RuntimeError as e:
+ if "ModuleNotFoundError" in str(e):
+ msg = """Please make sure the folder structure of an Algo Template follows
+ [algo_name]
+ ├── configs
+ │ ├── hyperparameters.yaml # automatically generated yaml from a set of ``template_configs``
+ │ ├── network.yaml # automatically generated network yaml from a set of ``template_configs``
+ │ ├── transforms_train.yaml # automatically generated yaml to define transforms for training
+ │ ├── transforms_validate.yaml # automatically generated yaml to define transforms for validation
+ │ └── transforms_infer.yaml # automatically generated yaml to define transforms for inference
+ └── scripts
+ ├── test.py
+ ├── __init__.py
+ └── validate.py
+ """
+ raise RuntimeError(msg) from e
+ self.algos[-1].name = algo_name
+ else:
+ self.algos = ensure_tuple(algos)
+
+ self.data_stats_filename = data_stats_filename
+ self.data_src_cfg_filename = data_src_cfg_name
+ self.history: List[Dict] = []
+
+ def set_data_stats(self, data_stats_filename: str):
+ """
+ Set the data stats filename
+
+ Args:
+ data_stats_filename: filename of datastats
+ """
+ self.data_stats_filename = data_stats_filename
+
+ def get_data_stats(self):
+ """Get the filename of the data stats"""
+ return self.data_stats_filename
+
+ def set_data_src(self, data_src_cfg_filename):
+ """
+ Set the data source filename
+
+ Args:
+ data_src_cfg_filename: filename of data_source file
+ """
+ self.data_src_cfg_filename = data_src_cfg_filename
+
+ def get_data_src(self):
+ """Get the data source filename"""
+ return self.data_src_cfg_filename
+
+ def get_history(self) -> List:
+ """get the history of the bundleAlgo object with their names/identifiers"""
+ return self.history
+
+ def generate(self, output_folder=".", num_fold: int = 5):
+ """
+ Generate the bundle scripts/configs for each bundleAlgo
+
+ Args:
+ output_folder: the output folder to save each algorithm.
+ num_fold: the number of cross validation fold
+ """
+ fold_idx = list(range(num_fold))
+ for algo in self.algos:
+ for f_id in ensure_tuple(fold_idx):
+ data_stats = self.get_data_stats()
+ data_src_cfg = self.get_data_src()
+ gen_algo = deepcopy(algo)
+ gen_algo.set_data_stats(data_stats)
+ gen_algo.set_data_source(data_src_cfg)
+ name = f"{gen_algo.name}_{f_id}"
+ gen_algo.export_to_disk(output_folder, name, fold=f_id)
+ algo_to_pickle(gen_algo, template_path=algo.template_path)
+ self.history.append({name: gen_algo}) # track the previous, may create a persistent history
diff --git a/monai/apps/auto3dseg/data_analyzer.py b/monai/apps/auto3dseg/data_analyzer.py
new file mode 100644
index 00000000000..fea04a994e1
--- /dev/null
+++ b/monai/apps/auto3dseg/data_analyzer.py
@@ -0,0 +1,275 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from os import path
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+
+from monai.apps.utils import get_logger
+from monai.auto3dseg import SegSummarizer
+from monai.auto3dseg.utils import datafold_read
+from monai.bundle import config_parser
+from monai.bundle.config_parser import ConfigParser
+from monai.data import DataLoader, Dataset
+from monai.data.utils import no_collation
+from monai.transforms import (
+ Compose,
+ EnsureChannelFirstd,
+ EnsureTyped,
+ Lambdad,
+ LoadImaged,
+ Orientationd,
+ SqueezeDimd,
+ ToDeviced,
+)
+from monai.utils import StrEnum, min_version, optional_import
+from monai.utils.enums import DataStatsKeys, ImageStatsKeys
+
+
+def strenum_representer(dumper, data):
+ return dumper.represent_scalar("tag:yaml.org,2002:str", data.value)
+
+
+if optional_import("yaml")[1]:
+ config_parser.yaml.SafeDumper.add_multi_representer(StrEnum, strenum_representer)
+
+tqdm, has_tqdm = optional_import("tqdm", "4.47.0", min_version, "tqdm")
+logger = get_logger(module_name=__name__)
+
+__all__ = ["DataAnalyzer"]
+
+
+def _argmax_if_multichannel(x):
+ return torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x
+
+
+class DataAnalyzer:
+ """
+ The DataAnalyzer automatically analyzes given medical image dataset and reports the statistics.
+ The module expects file paths to the image data and utilizes the LoadImaged transform to read the
+ files, which supports nii, nii.gz, png, jpg, bmp, npz, npy, and dcm formats. Currently, only
+ segmentation task is supported, so the user needs to provide paths to the image and label files
+ (if have). Also, label data format is preferred to be (1,H,W,D), with the label index in the
+ first dimension. If it is in onehot format, it will be converted to the preferred format.
+
+ Args:
+ datalist: a Python dictionary storing group, fold, and other information of the medical
+ image dataset, or a string to the JSON file storing the dictionary.
+ dataroot: user's local directory containing the datasets.
+ output_path: path to save the analysis result.
+ average: whether to average the statistical value across different image modalities.
+ do_ccp: apply the connected component algorithm to process the labels/images
+ device: a string specifying hardware (CUDA/CPU) utilized for the operations.
+ worker: number of workers to use for parallel processing. If device is cuda/GPU, worker has
+ to be 0.
+ image_key: a string that user specify for the image. The DataAnalyzer will look it up in the
+ datalist to locate the image files of the dataset.
+ label_key: a string that user specify for the label. The DataAnalyzer will look it up in the
+ datalist to locate the label files of the dataset. If label_key is NoneType or "None",
+ the DataAnalyzer will skip looking for labels and all label-related operations.
+ hist_bins: bins to compute histogram for each image channel.
+ hist_range: ranges to compute histogram for each image channel.
+ fmt: format used to save the analysis results. Defaults to "yaml".
+ histogram_only: whether to only compute histograms. Defaults to False.
+
+ Raises:
+ ValueError if device is GPU and worker > 0.
+
+ Examples:
+ .. code-block:: python
+
+ from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
+
+ datalist = {
+ "testing": [{"image": "image_003.nii.gz"}],
+ "training": [
+ {"fold": 0, "image": "image_001.nii.gz", "label": "label_001.nii.gz"},
+ {"fold": 0, "image": "image_002.nii.gz", "label": "label_002.nii.gz"},
+ {"fold": 1, "image": "image_001.nii.gz", "label": "label_001.nii.gz"},
+ {"fold": 1, "image": "image_004.nii.gz", "label": "label_004.nii.gz"},
+ ],
+ }
+
+ dataroot = '/datasets' # the directory where you have the image files (nii.gz)
+ DataAnalyzer(datalist, dataroot)
+
+ Notes:
+ The module can also be called from the command line interface (CLI).
+
+ For example:
+
+ .. code-block:: bash
+
+ python -m monai.apps.auto3dseg \\
+ DataAnalyzer \\
+ get_all_case_stats \\
+ --datalist="my_datalist.json" \\
+ --dataroot="my_dataroot_dir"
+
+ """
+
+ def __init__(
+ self,
+ datalist: Union[str, Dict],
+ dataroot: str = "",
+ output_path: str = "./data_stats.yaml",
+ average: bool = True,
+ do_ccp: bool = True,
+ device: Union[str, torch.device] = "cpu",
+ worker: int = 2,
+ image_key: str = "image",
+ label_key: Optional[str] = "label",
+ hist_bins: Optional[Union[list, int]] = 0,
+ hist_range: Optional[list] = None,
+ fmt: Optional[str] = "yaml",
+ histogram_only: bool = False,
+ ):
+ if path.isfile(output_path):
+ warnings.warn(f"File {output_path} already exists and will be overwritten.")
+ logger.debug(f"{output_path} will be overwritten by a new datastat.")
+
+ self.datalist = datalist
+ self.dataroot = dataroot
+ self.output_path = output_path
+ self.average = average
+ self.do_ccp = do_ccp
+ self.device = torch.device(device)
+ self.worker = 0 if (self.device.type == "cuda") else worker
+ self.image_key = image_key
+ self.label_key = None if label_key == "None" else label_key
+ self.hist_bins = hist_bins
+ self.hist_range: list = [-500, 500] if hist_range is None else hist_range
+ self.fmt = fmt
+ self.histogram_only = histogram_only
+
+ @staticmethod
+ def _check_data_uniformity(keys: List[str], result: Dict):
+ """
+ Check data uniformity since DataAnalyzer provides no support to multi-modal images with different
+ affine matrices/spacings due to monai transforms.
+
+ Args:
+ keys: a list of string-type keys under image_stats dictionary.
+
+ Returns:
+ False if one of the selected key values is not constant across the dataset images.
+
+ """
+
+ constant_props = [result[DataStatsKeys.SUMMARY][DataStatsKeys.IMAGE_STATS][key] for key in keys]
+ for prop in constant_props:
+ if "stdev" in prop and np.any(prop["stdev"]):
+ return False
+
+ return True
+
+ def get_all_case_stats(self, key="training", transform_list=None):
+ """
+ Get all case stats. Caller of the DataAnalyser class. The function iterates datalist and
+ call get_case_stats to generate stats. Then get_case_summary is called to combine results.
+
+ Args:
+ key: dataset key
+ transform_list: option list of transforms before SegSummarizer
+
+ Returns:
+ A data statistics dictionary containing
+ "stats_summary" (summary statistics of the entire datasets). Within stats_summary
+ there are "image_stats" (summarizing info of shape, channel, spacing, and etc
+ using operations_summary), "image_foreground_stats" (info of the intensity for the
+ non-zero labeled voxels), and "label_stats" (info of the labels, pixel percentage,
+ image_intensity, and each individual label in a list)
+ "stats_by_cases" (List type value. Each element of the list is statistics of
+ an image-label info. Within each element, there are: "image" (value is the
+ path to an image), "label" (value is the path to the corresponding label), "image_stats"
+ (summarizing info of shape, channel, spacing, and etc using operations),
+ "image_foreground_stats" (similar to the previous one but one foreground image), and
+ "label_stats" (stats of the individual labels )
+
+ Notes:
+ Since the backend of the statistics computation are torch/numpy, nan/inf value
+ may be generated and carried over in the computation. In such cases, the output
+ dictionary will include .nan/.inf in the statistics.
+
+ """
+ summarizer = SegSummarizer(
+ self.image_key,
+ self.label_key,
+ average=self.average,
+ do_ccp=self.do_ccp,
+ hist_bins=self.hist_bins,
+ hist_range=self.hist_range,
+ histogram_only=self.histogram_only,
+ )
+ keys = list(filter(None, [self.image_key, self.label_key]))
+ if transform_list is None:
+ transform_list = [
+ LoadImaged(keys=keys),
+ EnsureChannelFirstd(keys=keys), # this creates label to be (1,H,W,D)
+ Orientationd(keys=keys, axcodes="RAS"),
+ EnsureTyped(keys=keys, data_type="tensor"),
+ Lambdad(keys=self.label_key, func=_argmax_if_multichannel) if self.label_key else None,
+ SqueezeDimd(keys=self.label_key, dim=0) if self.label_key else None,
+ ToDeviced(keys=keys, device=self.device),
+ ]
+ transform_list.append(summarizer)
+
+ transform = Compose(transforms=list(filter(None, transform_list)))
+
+ files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key)
+ dataset = Dataset(data=files, transform=transform)
+ dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=self.worker, collate_fn=no_collation)
+ result = {DataStatsKeys.SUMMARY: {}, DataStatsKeys.BY_CASE: []}
+
+ if not has_tqdm:
+ warnings.warn("tqdm is not installed. not displaying the caching progress.")
+
+ for batch_data in tqdm(dataloader) if has_tqdm else dataloader:
+ d = batch_data[0]
+ stats_by_cases = {
+ DataStatsKeys.BY_CASE_IMAGE_PATH: d[DataStatsKeys.BY_CASE_IMAGE_PATH],
+ DataStatsKeys.BY_CASE_LABEL_PATH: d[DataStatsKeys.BY_CASE_LABEL_PATH],
+ DataStatsKeys.IMAGE_STATS: d[DataStatsKeys.IMAGE_STATS],
+ }
+ if self.hist_bins != 0:
+ stats_by_cases.update({DataStatsKeys.IMAGE_HISTOGRAM: d[DataStatsKeys.IMAGE_HISTOGRAM]})
+
+ if self.label_key is not None:
+ stats_by_cases.update(
+ {
+ DataStatsKeys.FG_IMAGE_STATS: d[DataStatsKeys.FG_IMAGE_STATS],
+ DataStatsKeys.LABEL_STATS: d[DataStatsKeys.LABEL_STATS],
+ }
+ )
+ result[DataStatsKeys.BY_CASE].append(stats_by_cases)
+
+ result[DataStatsKeys.SUMMARY] = summarizer.summarize(result[DataStatsKeys.BY_CASE])
+
+ if not self._check_data_uniformity([ImageStatsKeys.SPACING], result):
+ logger.warning("data spacing is not completely uniform. MONAI transforms may provide unexpected result")
+
+ if self.output_path:
+ ConfigParser.export_config_file(result, self.output_path, fmt=self.fmt, default_flow_style=None)
+
+ # manually release the variable from cuda memory
+ del d[self.image_key]
+ if self.label_key and self.label_key in d:
+ del d[self.label_key]
+
+ if self.device.type == "cuda":
+ # release unreferenced tensors to mitigate OOM
+ # limitation: https://github.com/pytorch/pytorch/issues/12873#issuecomment-482916237
+ torch.cuda.empty_cache()
+
+ return result
diff --git a/monai/apps/auto3dseg/ensemble_builder.py b/monai/apps/auto3dseg/ensemble_builder.py
new file mode 100644
index 00000000000..72ea557dc4e
--- /dev/null
+++ b/monai/apps/auto3dseg/ensemble_builder.py
@@ -0,0 +1,315 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Sequence
+from warnings import warn
+
+import numpy as np
+
+from monai.apps.auto3dseg.bundle_gen import BundleAlgo
+from monai.apps.utils import get_logger
+from monai.auto3dseg import concat_val_to_np
+from monai.bundle import ConfigParser
+from monai.transforms import MeanEnsemble, VoteEnsemble
+from monai.utils.enums import AlgoEnsembleKeys
+from monai.utils.misc import prob2class
+from monai.utils.module import look_up_option
+
+logger = get_logger(module_name=__name__)
+
+
+class AlgoEnsemble(ABC):
+ """
+ The base class of Ensemble methods
+ """
+
+ def __init__(self):
+ self.algos = []
+ self.mode = "mean"
+ self.infer_files = []
+ self.algo_ensemble = []
+
+ def set_algos(self, infer_algos):
+ """
+ Register model in the ensemble
+ """
+ self.algos = deepcopy(infer_algos)
+
+ def get_algo(self, identifier):
+ """
+ Get a model by identifier.
+
+ Args:
+ identifier: the name of the bundleAlgo
+ """
+ for algo in self.algos:
+ if identifier == algo[AlgoEnsembleKeys.ID]:
+ return algo
+
+ def get_algo_ensemble(self):
+ """
+ Get the algo ensemble after ranking or a empty list if ranking was not started.
+
+ Returns:
+ A list of Algo
+ """
+ return self.algo_ensemble
+
+ def set_infer_files(self, dataroot: str, data_list_file_path: str, data_key: str = "testing"):
+ """
+ Set the files to perform model inference.
+
+ Args:
+ dataroot: the path of the files
+ data_src_cfg_file: the data source file path
+ """
+ with open(data_list_file_path) as f:
+ datalist = json.load(f)
+
+ for d in datalist[data_key]:
+ self.infer_files.append({"image": os.path.join(dataroot, d["image"])})
+
+ def ensemble_pred(self, preds, sigmoid=True):
+ """
+ ensemble the results using either "mean" or "vote" method
+
+ Args:
+ preds: a list of probability prediction in Tensor-Like format.
+ sigmoid: use the sigmoid function to threshold probability one-hot map.
+
+ Returns:
+ a tensor which is the ensembled prediction.
+ """
+
+ if self.mode == "mean":
+ prob = MeanEnsemble()(preds)
+ return prob2class(prob, dim=0, keepdim=True, sigmoid=sigmoid)
+ elif self.mode == "vote":
+ classes = [prob2class(p, dim=0, keepdim=True, sigmoid=False) for p in preds]
+ return VoteEnsemble(num_classes=preds[0].shape[0])(classes)
+
+ def __call__(self, pred_param: Optional[Dict[str, Any]] = None):
+ """
+ Use the ensembled model to predict result.
+
+ Args:
+ pred_param: prediction parameter dictionary. The key has two groups: the first one will be consumed
+ in this function, and the second group will be passed to the `InferClass` to override the
+ parameters of the class functions.
+ The first group contains:
+ 'files_slices': a value type of `slice`. The files_slices will slice the infer_files and only
+ make prediction on the infer_files[file_slices].
+ 'mode': ensemble mode. Currently "mean" and "vote" (majority voting) schemes are supported.
+ 'sigmoid': use the sigmoid function (e.g. x>0.5) to convert the prediction probability map to
+ the label class prediction.
+
+ Returns:
+ A list of tensors.
+ """
+ if pred_param is None:
+ param = {}
+ else:
+ param = deepcopy(pred_param)
+
+ files = self.infer_files
+ if "files_slices" in param:
+ slices = param.pop("files_slices")
+ files = self.infer_files[slices]
+
+ if "mode" in param:
+ mode = param.pop("mode")
+ self.mode = look_up_option(mode, supported=["mean", "vote"])
+
+ sigmoid = True
+ if "sigmoid" in param:
+ sigmoid = param.pop("sigmoid")
+ sigmoid = look_up_option(sigmoid, supported=[True, False])
+
+ outputs = []
+ for i in range(len(files)):
+ print(i)
+ preds = []
+ infer_filename = self.infer_files[i]
+ for algo in self.algo_ensemble:
+ infer_instance = algo[AlgoEnsembleKeys.ALGO]
+ param.update({"files": [infer_filename]})
+ pred = infer_instance.predict(param)
+ preds.append(pred[0])
+ outputs.append(self.ensemble_pred(preds, sigmoid=sigmoid))
+ return outputs
+
+ @abstractmethod
+ def collect_algos(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class AlgoEnsembleBestN(AlgoEnsemble):
+ """
+ Ensemble method that select N model out of all using the models' best_metric scores
+
+ Args:
+ n_best: number of models to pick for ensemble (N).
+ """
+
+ def __init__(self, n_best: int = 5):
+
+ super().__init__()
+ self.n_best = n_best
+
+ def sort_score(self):
+ """
+ Sort the best_metrics
+ """
+ scores = concat_val_to_np(self.algos, [AlgoEnsembleKeys.SCORE])
+ return np.argsort(scores).tolist()
+
+ def collect_algos(self, n_best: int = -1):
+ """
+ Rank the algos by finding the top N (n_best) validation scores.
+ """
+
+ if n_best <= 0:
+ n_best = self.n_best
+
+ ranks = self.sort_score()
+ if len(ranks) < n_best:
+ raise ValueError("Number of available algos is less than user-defined N")
+
+ # get the indices that the rank is larger than N
+ indices = [i for (i, r) in enumerate(ranks) if r >= n_best]
+
+ # remove the found indices
+ indices = sorted(indices, reverse=True)
+
+ self.algo_ensemble = deepcopy(self.algos)
+ for idx in indices:
+ if idx < len(self.algo_ensemble):
+ self.algo_ensemble.pop(idx)
+
+
+class AlgoEnsembleBestByFold(AlgoEnsemble):
+ """
+ Ensemble method that select the best models that are the tops in each fold.
+
+ Args:
+ n_fold: number of cross-validation folds used in training
+ """
+
+ def __init__(self, n_fold: int = 5):
+
+ super().__init__()
+ self.n_fold = n_fold
+
+ def collect_algos(self):
+ """
+ Rank the algos by finding the best model in each cross-validation fold
+ """
+
+ self.algo_ensemble = []
+ for f_idx in range(self.n_fold):
+ best_score = -1.0
+ best_model: Optional[BundleAlgo] = None
+ for algo in self.algos:
+ identifier = algo[AlgoEnsembleKeys.ID].split("_")[-1]
+ try:
+ algo_id = int(identifier)
+ except ValueError as err:
+ raise ValueError(f"model identifier {identifier} is not number.") from err
+ if algo_id == f_idx and algo[AlgoEnsembleKeys.SCORE] > best_score:
+ best_model = algo
+ self.algo_ensemble.append(best_model)
+
+
+class AlgoEnsembleBuilder:
+ """
+ Build ensemble workflow from configs and arguments.
+
+ Args:
+ history: a collection of trained bundleAlgo algorithms.
+ data_src_cfg_filename: filename of the data source.
+
+ Examples:
+
+ .. code-block:: python
+
+ builder = AlgoEnsembleBuilder(history, data_src_cfg)
+ builder.set_ensemble_method(BundleAlgoEnsembleBestN(3))
+ ensemble = builder.get_ensemble()
+
+ result = ensemble.predict()
+ """
+
+ def __init__(self, history: Sequence[Dict], data_src_cfg_filename: Optional[str] = None):
+ self.infer_algos: List[Dict[AlgoEnsembleKeys, Any]] = []
+ self.ensemble: AlgoEnsemble
+ self.data_src_cfg = ConfigParser(globals=False)
+
+ if data_src_cfg_filename is not None and os.path.exists(str(data_src_cfg_filename)):
+ self.data_src_cfg.read_config(data_src_cfg_filename)
+
+ for h in history:
+ # load inference_config_paths
+ # raise warning/error if not found
+ if len(h) > 1:
+ raise ValueError(f"{h} should only contain one set of genAlgo key-value")
+
+ name = list(h.keys())[0]
+ gen_algo = h[name]
+ best_metric = gen_algo.get_score()
+ algo_path = gen_algo.output_path
+ infer_path = os.path.join(algo_path, "scripts", "infer.py")
+
+ if not os.path.isdir(algo_path):
+ warn(f"{gen_algo.output_path} is not a directory. Please check the path.")
+
+ if not os.path.isfile(infer_path):
+ warn(f"{infer_path} is not found. Please check the path.")
+
+ self.add_inferer(name, gen_algo, best_metric)
+
+ def add_inferer(self, identifier: str, gen_algo: BundleAlgo, best_metric: Optional[float] = None):
+ """
+ Add model inferer to the builder.
+
+ Args:
+ identifier: name of the bundleAlgo.
+ gen_algo: a trained BundleAlgo model object.
+ best_metric: the best metric in validation of the trained model.
+ """
+
+ if best_metric is None:
+ raise ValueError("Feature to re-validate is to be implemented")
+
+ algo = {AlgoEnsembleKeys.ID: identifier, AlgoEnsembleKeys.ALGO: gen_algo, AlgoEnsembleKeys.SCORE: best_metric}
+ self.infer_algos.append(algo)
+
+ def set_ensemble_method(self, ensemble: AlgoEnsemble, *args, **kwargs):
+ """
+ Set the ensemble method.
+
+ Args:
+ ensemble: the AlgoEnsemble to build.
+ """
+
+ ensemble.set_algos(self.infer_algos)
+ ensemble.collect_algos(*args, **kwargs)
+ ensemble.set_infer_files(self.data_src_cfg["dataroot"], self.data_src_cfg["datalist"])
+
+ self.ensemble = ensemble
+
+ def get_ensemble(self):
+ """Get the ensemble"""
+
+ return self.ensemble
diff --git a/monai/apps/auto3dseg/hpo_gen.py b/monai/apps/auto3dseg/hpo_gen.py
new file mode 100644
index 00000000000..f9f709053be
--- /dev/null
+++ b/monai/apps/auto3dseg/hpo_gen.py
@@ -0,0 +1,402 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from abc import abstractmethod
+from copy import deepcopy
+from warnings import warn
+
+from monai.apps.auto3dseg.bundle_gen import BundleAlgo
+from monai.apps.utils import get_logger
+from monai.auto3dseg import Algo, AlgoGen, algo_from_pickle, algo_to_pickle
+from monai.bundle.config_parser import ConfigParser
+from monai.utils import optional_import
+
+nni, has_nni = optional_import("nni")
+optuna, has_optuna = optional_import("optuna")
+logger = get_logger(module_name=__name__)
+
+__all__ = ["HPOGen", "NNIGen", "OptunaGen"]
+
+
+class HPOGen(AlgoGen):
+ """
+ The base class for hyperparameter optimization (HPO) interfaces to generate algos in the Auto3Dseg pipeline.
+ The auto-generated algos are saved at their ``output_path`` on the disk. The files in the ``output_path``
+ may contain scripts that define the algo, configuration files, and pickle files that save the internal states
+ of the algo before/after the training. Compared to the BundleGen class, HPOGen generates Algo on-the-fly, so
+ training and algo generation may be executed alternatively and take a long time to finish the generation process.
+
+ """
+
+ @abstractmethod
+ def get_hyperparameters(self):
+ """Get the hyperparameter from HPO."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def update_params(self, *args, **kwargs):
+ """Update Algo parameters according to the hyperparameters to be evaluated."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def set_score(self):
+ """Report the evaluated results to HPO."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def run_algo(self, *args, **kwargs):
+ """Interface for launch the training given the fetched hyperparameters."""
+ raise NotImplementedError
+
+
+class NNIGen(HPOGen):
+ """
+ Generate algorithms for the NNI to automate hyperparameter tuning. The module has two major interfaces:
+ ``__init__`` which prints out how to set up the NNI, and a trialCommand function ``run_algo`` for the NNI library to
+ start the trial of the algo. More about trialCommand function can be found in ``trail code`` section in NNI webpage
+ https://nni.readthedocs.io/en/latest/tutorials/hpo_quickstart_pytorch/main.html .
+
+ Args:
+ algo: an Algo object (e.g. BundleAlgo) with defined methods: ``get_output_path`` and train
+ and supports saving to and loading from pickle files via ``algo_from_pickle`` and ``algo_to_pickle``.
+ params: a set of parameter to override the algo if override is supported by Algo subclass.
+
+ Examples::
+
+ The experiment will keep generating new folders to save the model checkpoints, scripts, and configs if available.
+ ├── algorithm_templates
+ │ └── unet
+ ├── unet_0
+ │ ├── algo_object.pkl
+ │ ├── configs
+ │ └── scripts
+ ├── unet_0_learning_rate_0.01
+ │ ├── algo_object.pkl
+ │ ├── configs
+ │ ├── model_fold0
+ │ └── scripts
+ └── unet_0_learning_rate_0.1
+ ├── algo_object.pkl
+ ├── configs
+ ├── model_fold0
+ └── scripts
+
+ Notes:
+ The NNIGen will prepare the algorithms in a folder and suggest a command to replace trialCommand in the experiment
+ config. However, NNIGen will not trigger NNI. User needs to write their NNI experiment configs, and then run the
+ NNI command manually.
+ """
+
+ def __init__(self, algo=None, params=None):
+ self.algo: Algo
+ self.hint = ""
+ self.obj_filename = ""
+
+ if algo is not None:
+ if isinstance(algo, BundleAlgo):
+ if params is None:
+ self.algo = algo
+ else:
+ self.algo = deepcopy(algo)
+ name = os.path.basename(algo.get_output_path()) + "_override"
+ output_folder = os.path.dirname(algo.get_output_path())
+
+ params.update({"fill_with_datastats": False}) # just copy, not using datastats to fill
+ self.algo.export_to_disk(output_folder, name, **params)
+ else:
+ self.algo = algo
+
+ if isinstance(algo, BundleAlgo):
+ self.obj_filename = algo_to_pickle(self.algo, template_path=self.algo.template_path)
+ self.print_bundle_algo_instruction()
+ else:
+ self.obj_filename = algo_to_pickle(self.algo)
+ # nni instruction unknown
+
+ def get_obj_filename(self):
+ """Return the filename of the dumped pickle algo object."""
+ return self.obj_filename
+
+ def print_bundle_algo_instruction(self):
+ """
+ Print how to write the trial commands for Bundle Algo.
+ """
+ hint = "python -m monai.apps.auto3dseg NNIGen run_algo "
+ logger.info("=" * 140)
+ logger.info("If NNI will run in your local env: ")
+ logger.info("1. Add the following line to the trialCommand in your NNI config: ")
+ logger.info(f"{hint} {self.obj_filename} {{result_dir}}")
+ logger.info("-" * 140)
+ logger.info("If NNI will run in a remote env: ")
+ logger.info(
+ f"1. Copy the algorithm_templates folder {self.algo.template_path} to remote {{remote_algorithm_templates_dir}}"
+ )
+ logger.info(f"2. Copy the older {self.algo.get_output_path()} to the remote machine {{remote_algo_dir}}")
+ logger.info("Then add the following line to the trialCommand in your NNI config: ")
+ logger.info(f"{hint} {{remote_algo_dir}} {{result_dir}} {{remote_algorithm_templates_dir}}")
+ logger.info("=" * 140)
+
+ def get_hyperparameters(self):
+ """
+ Get parameter for next round of training from NNI server.
+ """
+ if has_nni:
+ return nni.get_next_parameter()
+ warn("NNI is not detected. The code will continue to run without NNI.")
+ return {}
+
+ def update_params(self, params: dict):
+ """
+ Translate the parameter from monai bundle to meet NNI requirements.
+
+ Args:
+ params: a dict of parameters.
+ """
+ self.params = params
+
+ def get_task_id(self):
+ """
+ Get the identifier of the current experiment. In the format of listing the searching parameter name and values
+ connected by underscore in the file name.
+ """
+ return "".join(f"_{k}_{v}" for k, v in self.params.items()) or "_None"
+
+ def generate(self, output_folder: str = ".") -> None:
+ """
+ Generate the record for each Algo. If it is a BundleAlgo, it will generate the config files.
+
+ Args:
+ output_folder: the directory nni will save the results to.
+ """
+ task_id = self.get_task_id()
+ task_prefix = os.path.basename(self.algo.get_output_path())
+ write_path = os.path.join(output_folder, task_prefix + task_id)
+ self.obj_filename = os.path.join(write_path, "algo_object.pkl")
+
+ if isinstance(self.algo, BundleAlgo):
+ self.algo.export_to_disk(output_folder, task_prefix + task_id, fill_with_datastats=False)
+ else:
+
+ ConfigParser.export_config_file(self.params, write_path)
+ logger.info(write_path)
+
+ def set_score(self, acc):
+ """
+ Report the acc to NNI server.
+ """
+ if has_nni:
+ nni.report_final_result(acc)
+ else:
+ warn("NNI is not detected. The code will continue to run without NNI.")
+
+ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path=None) -> None:
+ """
+ The python interface for NNI to run.
+
+ Args:
+ obj_filename: the pickle-exported Algo object.
+ output_folder: the root path of the algorithms templates.
+ template_path: the algorithm_template. It must contain algo.py in the follow path:
+ ``{algorithm_templates_dir}/{network}/scripts/algo.py``
+ """
+ if not os.path.isfile(obj_filename):
+ raise ValueError(f"{obj_filename} is not found")
+
+ self.algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)
+
+ if isinstance(self.algo, BundleAlgo): # algo's template path needs override
+ self.algo.template_path = algo_meta_data["template_path"]
+
+ # step 1 sample hyperparams
+ params = self.get_hyperparameters()
+ # step 2 set the update params for the algo to run in the next trial
+ self.update_params(params)
+ # step 3 generate the folder to save checkpoints and train
+ self.generate(output_folder)
+ self.algo.train(self.params)
+ # step 4 report validation acc to controller
+ acc = self.algo.get_score()
+ if isinstance(self.algo, BundleAlgo):
+ algo_to_pickle(self.algo, template_path=self.algo.template_path, best_metrics=acc)
+ else:
+ algo_to_pickle(self.algo, best_metrics=acc)
+ self.set_score(acc)
+
+
+class OptunaGen(HPOGen):
+ """
+ Generate algorithms for the Optuna to automate hyperparameter tuning. Please refer to NNI and Optuna
+ (https://optuna.readthedocs.io/en/stable/) for more information. Optuna has different running scheme
+ compared to NNI. The hyperparameter samples come from a trial object (trial.suggest...) created by Optuna,
+ so OptunaGen needs to accept this trial object as input. Meanwhile, Optuna calls OptunaGen,
+ thus OptunaGen.__call__() should return the accuracy. Use functools.partial to wrap OptunaGen
+ for addition input arguments.
+
+ Args:
+ algo: an Algo object (e.g. BundleAlgo). The object must at least define two methods: get_output_path and train
+ and supports saving to and loading from pickle files via ``algo_from_pickle`` and ``algo_to_pickle``.
+ params: a set of parameter to override the algo if override is supported by Algo subclass.
+
+ Examples::
+
+ The experiment will keep generating new folders to save the model checkpoints, scripts, and configs if available.
+ ├── algorithm_templates
+ │ └── unet
+ ├── unet_0
+ │ ├── algo_object.pkl
+ │ ├── configs
+ │ └── scripts
+ ├── unet_0_learning_rate_0.01
+ │ ├── algo_object.pkl
+ │ ├── configs
+ │ ├── model_fold0
+ │ └── scripts
+ └── unet_0_learning_rate_0.1
+ ├── algo_object.pkl
+ ├── configs
+ ├── model_fold0
+ └── scripts
+
+ Notes:
+ Different from NNI and NNIGen, OptunaGen and Optuna can be ran within the Python process.
+
+ """
+
+ def __init__(self, algo=None, params=None):
+ self.algo: Algo
+ self.obj_filename = ""
+
+ if algo is not None:
+ if isinstance(algo, BundleAlgo):
+ if params is None:
+ self.algo = algo
+ else:
+ self.algo = deepcopy(algo)
+ name = os.path.basename(algo.get_output_path()) + "_override"
+ output_folder = os.path.dirname(algo.get_output_path())
+
+ params.update({"fill_with_datastats": False}) # just copy, not using datastats to fill
+ self.algo.export_to_disk(output_folder, name, **params)
+ else:
+ self.algo = algo
+
+ if isinstance(algo, BundleAlgo):
+ self.obj_filename = algo_to_pickle(self.algo, template_path=self.algo.template_path)
+ else:
+ self.obj_filename = algo_to_pickle(self.algo)
+ # nni instruction unknown
+
+ def get_obj_filename(self):
+ """Return the dumped pickle object of algo."""
+ return self.obj_filename
+
+ def get_hyperparameters(self):
+ """
+ Get parameter for next round of training from optuna trial object.
+ This function requires user rewrite during usage for different search space.
+ """
+ if has_optuna:
+ logger.info("Please rewrite this code by creating a child class")
+ return {"learning_rate": self.trial.suggest_float("learning_rate", 0.0001, 0.1)}
+ else:
+ warn("Optuna is not detected. The code will continue to run without Optuna.")
+ return {}
+
+ def set_score(self, acc):
+ """Set the accuracy score"""
+ self.acc = acc
+
+ def set_trial(self, trial):
+ """Set the Optuna trial"""
+ self.trial = trial
+
+ def __call__(self, trial, obj_filename: str, output_folder: str = ".", template_path=None):
+ """
+ Callabe that Optuna will use to optimize the hyper-parameters
+
+ Args:
+ obj_filename: the pickle-exported Algo object.
+ output_folder: the root path of the algorithms templates.
+ template_path: the algorithm_template. It must contain algo.py in the follow path:
+ ``{algorithm_templates_dir}/{network}/scripts/algo.py``
+ """
+ self.set_trial(trial)
+ self.run_algo(obj_filename, output_folder, template_path)
+ return self.acc
+
+ def update_params(self, params: dict):
+ """
+ Translate the parameter from monai bundle.
+
+ Args:
+ params: a dict of parameters.
+ """
+ self.params = params
+
+ def get_task_id(self):
+ """
+ Get the identifier of the current experiment. In the format of listing the searching parameter name and values
+ connected by underscore in the file name.
+ """
+ return "".join(f"_{k}_{v}" for k, v in self.params.items()) or "_None"
+
+ def generate(self, output_folder: str = ".") -> None:
+ """
+ Generate the record for each Algo. If it is a BundleAlgo, it will generate the config files.
+
+ Args:
+ output_folder: the directory nni will save the results to.
+ """
+ task_id = self.get_task_id()
+ task_prefix = os.path.basename(self.algo.get_output_path())
+ write_path = os.path.join(output_folder, task_prefix + task_id)
+ self.obj_filename = os.path.join(write_path, "algo_object.pkl")
+
+ if isinstance(self.algo, BundleAlgo):
+ self.algo.export_to_disk(output_folder, task_prefix + task_id, fill_with_datastats=False)
+ else:
+
+ ConfigParser.export_config_file(self.params, write_path)
+ logger.info(write_path)
+
+ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path=None) -> None:
+ """
+ The python interface for NNI to run.
+
+ Args:
+ obj_filename: the pickle-exported Algo object.
+ output_folder: the root path of the algorithms templates.
+ template_path: the algorithm_template. It must contain algo.py in the follow path:
+ ``{algorithm_templates_dir}/{network}/scripts/algo.py``
+ """
+ if not os.path.isfile(obj_filename):
+ raise ValueError(f"{obj_filename} is not found")
+
+ self.algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)
+
+ if isinstance(self.algo, BundleAlgo): # algo's template path needs override
+ self.algo.template_path = algo_meta_data["template_path"]
+
+ # step 1 sample hyperparams
+ params = self.get_hyperparameters()
+ # step 2 set the update params for the algo to run in the next trial
+ self.update_params(params)
+ # step 3 generate the folder to save checkpoints and train
+ self.generate(output_folder)
+ self.algo.train(self.params)
+ # step 4 report validation acc to controller
+ acc = self.algo.get_score()
+ if isinstance(self.algo, BundleAlgo):
+ algo_to_pickle(self.algo, template_path=self.algo.template_path, best_metrics=acc)
+ else:
+ algo_to_pickle(self.algo, best_metrics=acc)
+ self.set_score(acc)
diff --git a/monai/apps/auto3dseg/utils.py b/monai/apps/auto3dseg/utils.py
new file mode 100644
index 00000000000..82efb309643
--- /dev/null
+++ b/monai/apps/auto3dseg/utils.py
@@ -0,0 +1,67 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Dict, List, Optional
+
+from monai.apps.auto3dseg.bundle_gen import BundleAlgo
+from monai.auto3dseg import algo_from_pickle, algo_to_pickle
+
+
+def import_bundle_algo_history(
+ output_folder: str = ".", template_path: Optional[str] = None, only_trained: bool = True
+) -> List:
+ """
+ import the history of the bundleAlgo object with their names/identifiers
+
+ Args:
+ output_folder: the root path of the algorithms templates.
+ template_path: the algorithm_template. It must contain algo.py in the follow path:
+ ``{algorithm_templates_dir}/{network}/scripts/algo.py``.
+ only_trained: only read the algo history if the algo is trained.
+ """
+
+ history = []
+
+ for name in os.listdir(output_folder):
+ write_path = os.path.join(output_folder, name)
+
+ if not os.path.isdir(write_path):
+ continue
+
+ obj_filename = os.path.join(write_path, "algo_object.pkl")
+ if not os.path.isfile(obj_filename): # saved mode pkl
+ continue
+
+ algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)
+
+ if isinstance(algo, BundleAlgo): # algo's template path needs override
+ algo.template_path = algo_meta_data["template_path"]
+
+ if only_trained:
+ if "best_metrics" in algo_meta_data:
+ history.append({name: algo})
+ else:
+ history.append({name: algo})
+
+ return history
+
+
+def export_bundle_algo_history(history: List[Dict[str, BundleAlgo]]):
+ """
+ Save all the BundleAlgo in the history to algo_object.pkl in each individual folder
+
+ Args:
+ history: a List of Bundle. Typically, the history can be obtained from BundleGen get_history method
+ """
+ for task in history:
+ for _, algo in task.items():
+ algo_to_pickle(algo, template_path=algo.template_path)
diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py
index 75fda6da346..e5992711dad 100644
--- a/monai/apps/datasets.py
+++ b/monai/apps/datasets.py
@@ -53,15 +53,15 @@ class MedNISTDataset(Randomizable, CacheDataset):
if expected file already exists, skip downloading even set it to True.
user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory.
seed: random seed to randomly split training, validation and test datasets, default is 0.
- val_frac: percentage of of validation fraction in the whole dataset, default is 0.1.
- test_frac: percentage of of test fraction in the whole dataset, default is 0.1.
+ val_frac: percentage of validation fraction in the whole dataset, default is 0.1.
+ test_frac: percentage of test fraction in the whole dataset, default is 0.1.
cache_num: number of items to be cached. Default is `sys.maxsize`.
will take the minimum of (cache_num, data_length x cache_rate, data_length).
cache_rate: percentage of cached data in total, default is 1.0 (cache all).
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads to use.
If num_workers is None then the number returned by os.cpu_count() is used.
- If a value less than 1 is speficied, 1 will be used instead.
+ If a value less than 1 is specified, 1 will be used instead.
progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
@@ -205,7 +205,7 @@ class DecathlonDataset(Randomizable, CacheDataset):
download: whether to download and extract the Decathlon from resource link, default is False.
if expected file already exists, skip downloading even set it to True.
user can manually copy tar file or dataset folder to the root directory.
- val_frac: percentage of of validation fraction in the whole dataset, default is 0.2.
+ val_frac: percentage of validation fraction in the whole dataset, default is 0.2.
seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0.
note to set same seed for `training` and `validation` sections.
cache_num: number of items to be cached. Default is `sys.maxsize`.
@@ -214,7 +214,7 @@ class DecathlonDataset(Randomizable, CacheDataset):
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads to use.
If num_workers is None then the number returned by os.cpu_count() is used.
- If a value less than 1 is speficied, 1 will be used instead.
+ If a value less than 1 is specified, 1 will be used instead.
progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
@@ -395,7 +395,7 @@ class TciaDataset(Randomizable, CacheDataset):
and generate items for training, validation or test.
The Highdicom library is used to load dicom data with modality "SEG", but only a part of collections are
- supoorted, such as: "C4KC-KiTS", "NSCLC-Radiomics", "NSCLC-Radiomics-Interobserver1", " QIN-PROSTATE-Repeatability"
+ supported, such as: "C4KC-KiTS", "NSCLC-Radiomics", "NSCLC-Radiomics-Interobserver1", " QIN-PROSTATE-Repeatability"
and "PROSTATEx". Therefore, if "seg" is included in `keys` of the `LoadImaged` transform and loading some
other collections, errors may be raised. For supported collections, the original "SEG" information may not
always be consistent for each dicom file. Therefore, to avoid creating different format of labels, please use
@@ -431,7 +431,7 @@ class TciaDataset(Randomizable, CacheDataset):
specific_tags: tags that will be loaded for "SEG" series. This argument will be used in
`monai.data.PydicomReader`. Default is [(0x0008, 0x1115), (0x0008,0x1140), (0x3006, 0x0010),
(0x0020,0x000D), (0x0010,0x0010), (0x0010,0x0020), (0x0020,0x0011), (0x0020,0x0012)].
- val_frac: percentage of of validation fraction in the whole dataset, default is 0.2.
+ val_frac: percentage of validation fraction in the whole dataset, default is 0.2.
seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0.
note to set same seed for `training` and `validation` sections.
cache_num: number of items to be cached. Default is `sys.maxsize`.
@@ -440,7 +440,7 @@ class TciaDataset(Randomizable, CacheDataset):
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads to use.
If num_workers is None then the number returned by os.cpu_count() is used.
- If a value less than 1 is speficied, 1 will be used instead.
+ If a value less than 1 is specified, 1 will be used instead.
progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
@@ -701,7 +701,7 @@ def __init__(self, dataset_cls, nfolds: int = 5, seed: int = 0, **dataset_params
def get_dataset(self, folds: Union[Sequence[int], int], **dataset_params):
"""
- Generate dataset based on the specified fold indice in the cross validation group.
+ Generate dataset based on the specified fold indices in the cross validation group.
Args:
folds: index of folds for training or validation, if a list of values, concatenate the data.
diff --git a/monai/apps/deepedit/interaction.py b/monai/apps/deepedit/interaction.py
index 04fabec06d1..dce81f095e8 100644
--- a/monai/apps/deepedit/interaction.py
+++ b/monai/apps/deepedit/interaction.py
@@ -37,6 +37,7 @@ class Interaction:
train: True for training mode or False for evaluation mode
click_probability_key: key to click/interaction probability
label_names: Dict of label names
+ max_interactions: maximum number of interactions per iteration
"""
def __init__(
@@ -44,8 +45,9 @@ def __init__(
deepgrow_probability: float,
transforms: Union[Sequence[Callable], Callable],
train: bool,
- label_names: Dict[str, int],
+ label_names: Union[None, Dict[str, int]] = None,
click_probability_key: str = "probability",
+ max_interactions: int = 1,
) -> None:
self.deepgrow_probability = deepgrow_probability
@@ -53,40 +55,38 @@ def __init__(
self.train = train
self.label_names = label_names
self.click_probability_key = click_probability_key
+ self.max_interactions = max_interactions
def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]):
-
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
if np.random.choice([True, False], p=[self.deepgrow_probability, 1 - self.deepgrow_probability]):
-
- # Run the inner loop only once
- inputs, _ = engine.prepare_batch(batchdata)
- inputs = inputs.to(engine.state.device)
-
- engine.fire_event(IterationEvents.INNER_ITERATION_STARTED)
-
- engine.network.eval()
- with torch.no_grad():
- if engine.amp:
- with torch.cuda.amp.autocast():
+ for j in range(self.max_interactions):
+ inputs, _ = engine.prepare_batch(batchdata)
+ inputs = inputs.to(engine.state.device)
+
+ engine.fire_event(IterationEvents.INNER_ITERATION_STARTED)
+ engine.network.eval()
+
+ with torch.no_grad():
+ if engine.amp:
+ with torch.cuda.amp.autocast():
+ predictions = engine.inferer(inputs, engine.network)
+ else:
predictions = engine.inferer(inputs, engine.network)
- else:
- predictions = engine.inferer(inputs, engine.network)
- batchdata.update({CommonKeys.PRED: predictions})
-
- # decollate/collate batchdata to execute click transforms
- batchdata_list = decollate_batch(batchdata, detach=True)
-
- for i in range(len(batchdata_list)):
- batchdata_list[i][self.click_probability_key] = 1.0
- batchdata_list[i] = self.transforms(batchdata_list[i])
-
- batchdata = list_data_collate(batchdata_list)
-
- engine.fire_event(IterationEvents.INNER_ITERATION_COMPLETED)
-
+ batchdata.update({CommonKeys.PRED: predictions})
+
+ # decollate/collate batchdata to execute click transforms
+ batchdata_list = decollate_batch(batchdata, detach=True)
+ for i in range(len(batchdata_list)):
+ batchdata_list[i][self.click_probability_key] = (
+ (1.0 - ((1.0 / self.max_interactions) * j)) if self.train else 1.0
+ )
+ batchdata_list[i] = self.transforms(batchdata_list[i])
+
+ batchdata = list_data_collate(batchdata_list)
+ engine.fire_event(IterationEvents.INNER_ITERATION_COMPLETED)
else:
# zero out input guidance channels
batchdata_list = decollate_batch(batchdata, detach=True)
@@ -96,5 +96,4 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd
# first item in batch only
engine.state.batch = batchdata
-
return engine._iteration(engine, batchdata)
diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py
index 5778127bf38..76b9e18cc70 100644
--- a/monai/apps/deepedit/transforms.py
+++ b/monai/apps/deepedit/transforms.py
@@ -13,7 +13,7 @@
import logging
import random
import warnings
-from typing import Dict, Hashable, Mapping, Optional
+from typing import Dict, Hashable, List, Mapping, Optional
import numpy as np
import torch
@@ -28,7 +28,6 @@
logger = logging.getLogger(__name__)
-
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
@@ -98,22 +97,22 @@ def __init__(self, keys: KeysCollection, label_names=None, allow_missing_keys: b
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d: Dict = dict(data)
for key in self.key_iterator(d):
- if key == "label":
- # Dictionary containing new label numbers
- new_label_names = {}
- label = np.zeros(d[key].shape)
- # Making sure the range values and number of labels are the same
- for idx, (key_label, val_label) in enumerate(self.label_names.items(), start=1):
- if key_label != "background":
- new_label_names[key_label] = idx
- label[d[key] == val_label] = idx
- if key_label == "background":
- new_label_names["background"] = 0
-
- d["label_names"] = new_label_names
- d[key] = label
+ # Dictionary containing new label numbers
+ new_label_names = {}
+ label = np.zeros(d[key].shape)
+ # Making sure the range values and number of labels are the same
+ for idx, (key_label, val_label) in enumerate(self.label_names.items(), start=1):
+ if key_label != "background":
+ new_label_names[key_label] = idx
+ label[d[key] == val_label] = idx
+ if key_label == "background":
+ new_label_names["background"] = 0
+
+ d["label_names"] = new_label_names
+ if isinstance(d[key], MetaTensor):
+ d[key].array = label
else:
- warnings.warn("This transform only applies to the label")
+ d[key] = label
return d
@@ -500,13 +499,14 @@ def __init__(
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)
- self.guidance = guidance
+ self.guidance_key = guidance
self.discrepancy = discrepancy
self.probability = probability
self._will_interact = None
self.is_pos = None
self.is_other = None
self.default_guidance = None
+ self.guidance: Dict[str, List[List[int]]] = {}
def randomize(self, data=None):
probability = data[self.probability]
@@ -559,31 +559,30 @@ def add_guidance(self, guidance, discrepancy, label_names, labels):
tmp_label = np.copy(labels)
tmp_label[tmp_label != label_names[key_label]] = 0
tmp_label = (tmp_label > 0.5).astype(np.float32)
- self.tmp_guidance[key_label].append(self.find_guidance(discrepancy[1] * tmp_label))
+ self.guidance[key_label].append(self.find_guidance(discrepancy[1] * tmp_label))
else:
tmp_label = np.copy(labels)
tmp_label[tmp_label != label_names[key_label]] = 1
tmp_label = 1 - tmp_label
- self.tmp_guidance[key_label].append(self.find_guidance(discrepancy[1] * tmp_label))
+ self.guidance[key_label].append(self.find_guidance(discrepancy[1] * tmp_label))
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d: Dict = dict(data)
- guidance = d[self.guidance]
+ guidance = d[self.guidance_key]
discrepancy = d[self.discrepancy]
self.randomize(data)
if self._will_interact:
# Convert all guidance to lists so new guidance can be easily appended
- self.tmp_guidance = {}
for key_label in d["label_names"].keys():
tmp_gui = guidance[key_label]
tmp_gui = tmp_gui.tolist() if isinstance(tmp_gui, np.ndarray) else tmp_gui
tmp_gui = json.loads(tmp_gui) if isinstance(tmp_gui, str) else tmp_gui
- self.tmp_guidance[key_label] = [j for j in tmp_gui if -1 not in j]
+ self.guidance[key_label] = [j for j in tmp_gui if -1 not in j]
# Add guidance according to discrepancy
for key_label in d["label_names"].keys():
# Add guidance based on discrepancy
- self.add_guidance(self.tmp_guidance[key_label], discrepancy[key_label], d["label_names"], d["label"])
+ self.add_guidance(self.guidance[key_label], discrepancy[key_label], d["label_names"], d["label"])
# Checking the number of clicks
num_clicks = random.randint(1, 10)
@@ -595,12 +594,12 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
pass
else:
keep_guidance.append(aux_label)
- counter = counter + len(self.tmp_guidance[aux_label])
+ counter = counter + len(self.guidance[aux_label])
# If collected clicks is bigger than max clicks, discard the others
if counter >= num_clicks:
for key_label in d["label_names"].keys():
if key_label not in keep_guidance:
- self.tmp_guidance[key_label] = []
+ self.guidance[key_label] = []
logger.info(f"Number of simulated clicks: {counter}")
break
@@ -608,7 +607,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
if len(keep_guidance) == len(d["label_names"].keys()):
logger.info(f"Number of simulated clicks: {counter}")
break
-
+ d[self.guidance_key] = self.guidance # Update the guidance
return d
@@ -626,7 +625,7 @@ class AddGuidanceFromPointsDeepEditd(Transform):
for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
the metadata is a dictionary object which contains: filename, original_shape, etc.
if None, will try to construct meta_keys by `{ref_image}_{meta_key_postfix}`.
- meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to to fetch the metadata according
+ meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to fetch the metadata according
to the key data, default is `meta_dict`, the metadata is a dictionary object.
For example, to handle key `image`, read/write affine matrices from the
metadata `image_meta_dict` dictionary's `affine` field.
diff --git a/monai/apps/deepgrow/dataset.py b/monai/apps/deepgrow/dataset.py
index 3de8e9cc135..3e4aa296ab7 100644
--- a/monai/apps/deepgrow/dataset.py
+++ b/monai/apps/deepgrow/dataset.py
@@ -15,9 +15,8 @@
import numpy as np
-from monai.transforms import Compose, EnsureChannelFirstd, FromMetaTensord, LoadImaged, Orientationd, Spacingd, ToNumpyd
+from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, Orientationd, Spacingd, SqueezeDimd
from monai.utils import GridSampleMode
-from monai.utils.enums import PostFix
def create_dataset(
@@ -85,12 +84,12 @@ def create_dataset(
transforms = _default_transforms(image_key, label_key, pixdim) if transforms is None else transforms
new_datalist = []
- for idx in range(len(datalist)):
+ for idx, item in enumerate(datalist):
if limit and idx >= limit:
break
- image = datalist[idx][image_key]
- label = datalist[idx].get(label_key, None)
+ image = item[image_key]
+ label = item.get(label_key, None)
if base_dir:
image = os.path.join(base_dir, image)
label = os.path.join(base_dir, label) if label else None
@@ -100,19 +99,29 @@ def create_dataset(
logging.info(f"Image: {image}; Label: {label if label else None}")
data = transforms({image_key: image, label_key: label})
+
+ vol_image = data[image_key]
+ vol_label = data.get(label_key)
+ logging.info(f"Image (transform): {vol_image.shape}; Label: {None if vol_label is None else vol_label.shape}")
+
+ vol_image = np.moveaxis(vol_image, -1, 0)
+ if vol_label is not None:
+ vol_label = np.moveaxis(vol_label, -1, 0)
+ logging.info(f"Image (final): {vol_image.shape}; Label: {None if vol_label is None else vol_label.shape}")
+
if dimension == 2:
data = _save_data_2d(
vol_idx=idx,
- vol_image=data[image_key],
- vol_label=data[label_key],
+ vol_image=vol_image,
+ vol_label=vol_label,
dataset_dir=output_dir,
relative_path=relative_path,
)
else:
data = _save_data_3d(
vol_idx=idx,
- vol_image=data[image_key],
- vol_label=data[label_key],
+ vol_image=vol_image,
+ vol_label=vol_label,
dataset_dir=output_dir,
relative_path=relative_path,
)
@@ -129,25 +138,13 @@ def _default_transforms(image_key, label_key, pixdim):
EnsureChannelFirstd(keys=keys),
Orientationd(keys=keys, axcodes="RAS"),
Spacingd(keys=keys, pixdim=pixdim, mode=mode),
- FromMetaTensord(keys=keys),
- ToNumpyd(keys=keys + [PostFix.meta(k) for k in keys]),
+ SqueezeDimd(keys=keys),
]
)
def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
data_list = []
- if len(vol_image.shape) == 4:
- logging.info(
- "4D-Image, pick only first series; Image: {}; Label: {}".format(
- vol_image.shape, vol_label.shape if vol_label is not None else None
- )
- )
- vol_image = vol_image[0]
- vol_image = np.moveaxis(vol_image, -1, 0)
- if vol_label is not None:
- vol_label = vol_label[0]
- vol_label = np.moveaxis(vol_label, -1, 0)
image_count = 0
label_count = 0
@@ -216,15 +213,6 @@ def _save_data_2d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
def _save_data_3d(vol_idx, vol_image, vol_label, dataset_dir, relative_path):
data_list = []
- if len(vol_image.shape) == 4:
- logging.info(
- "4D-Image, pick only first series; Image: {}; Label: {}".format(
- vol_image.shape, vol_label.shape if vol_label is not None else None
- )
- )
- vol_image = vol_image[0]
- vol_image = np.moveaxis(vol_image, -1, 0)
-
image_count = 0
label_count = 0
unique_labels_count = 0
diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py
index 537a21a9665..9340a80f7a7 100644
--- a/monai/apps/deepgrow/transforms.py
+++ b/monai/apps/deepgrow/transforms.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union
+from typing import Callable, Dict, Hashable, Optional, Sequence, Union
import numpy as np
import torch
@@ -656,8 +656,8 @@ def bounding_box(self, points, img_shape):
def __call__(self, data):
d: Dict = dict(data)
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
return d
guidance = d[self.guidance]
@@ -720,7 +720,7 @@ class ResizeGuidanced(Transform):
for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
the metadata is a dictionary object which contains: filename, original_shape, etc.
if None, will try to construct meta_keys by `{ref_image}_{meta_key_postfix}`.
- meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to to fetch the metadata according
+ meta_key_postfix: if meta_key is None, use `{ref_image}_{meta_key_postfix}` to fetch the metadata according
to the key data, default is `meta_dict`, the metadata is a dictionary object.
For example, to handle key `image`, read/write affine matrices from the
metadata `image_meta_dict` dictionary's `affine` field.
diff --git a/monai/apps/detection/metrics/coco.py b/monai/apps/detection/metrics/coco.py
index 7c6b02fe4fd..8dba3fa7da5 100644
--- a/monai/apps/detection/metrics/coco.py
+++ b/monai/apps/detection/metrics/coco.py
@@ -56,7 +56,6 @@
# The views and conclusions contained in the software and documentation are those
# of the authors and should not be interpreted as representing official policies,
# either expressed or implied, of the FreeBSD Project.
-
"""
This script is almost same with https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/evaluator/detection/coco.py
The changes include 1) code reformatting, 2) docstrings.
@@ -405,7 +404,7 @@ def _compute_statistics(
Args:
results_list (List[Dict[int, Dict[str, np.ndarray]]]): list with result s per image (in list)
- per cateory (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
+ per category (dict). Inner Dict contains multiple results obtained by :func:`box_matching_batch`.
- `dtMatches`: matched detections [T, D], where T = number of
thresholds, D = number of detections
diff --git a/monai/apps/detection/metrics/matching.py b/monai/apps/detection/metrics/matching.py
index 6df026bf548..37e6e2fa069 100644
--- a/monai/apps/detection/metrics/matching.py
+++ b/monai/apps/detection/metrics/matching.py
@@ -56,7 +56,6 @@
# The views and conclusions contained in the software and documentation are those
# of the authors and should not be interpreted as representing official policies,
# either expressed or implied, of the FreeBSD Project.
-
"""
This script is almost same with https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/evaluator/detection/matching.py
The changes include 1) code reformatting, 2) docstrings,
@@ -162,7 +161,7 @@ def matching_batch(
result = {} # dict contains results for each class in one image
for c in img_classes:
pred_mask = pclasses == c # bool mask predictions with current class
- gt_mask = gclasses == c # nool mask ground trtuh with current class
+ gt_mask = gclasses == c # bool mask ground truth with current class
if not np.any(gt_mask): # no ground truth
result[c] = _matching_no_gt(
diff --git a/monai/apps/detection/networks/retinanet_detector.py b/monai/apps/detection/networks/retinanet_detector.py
index fd270ee094a..4c6f165439a 100644
--- a/monai/apps/detection/networks/retinanet_detector.py
+++ b/monai/apps/detection/networks/retinanet_detector.py
@@ -32,7 +32,6 @@
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
-
"""
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
@@ -113,7 +112,7 @@ class RetinaNetDetector(nn.Module):
- spatial_dims (int) is the spatial dimension of the network, we support both 2D and 3D.
- num_classes (int) is the number of classes, excluding the background.
- - size_divisible (int or Sequene[int]) is the expection on the input image shape.
+ - size_divisible (int or Sequence[int]) is the expectation on the input image shape.
The network needs the input spatial_size to be divisible by size_divisible, length should be 2 or 3.
- cls_key (str) is the key to represent classification in the output dict.
- box_reg_key (str) is the key to represent box regression in the output dict.
@@ -318,13 +317,17 @@ def set_regular_matcher(self, fg_iou_thresh: float, bg_iou_thresh: float, allow_
Args:
fg_iou_thresh: foreground IoU threshold for Matcher, considered as matched if IoU > fg_iou_thresh
bg_iou_thresh: background IoU threshold for Matcher, considered as not matched if IoU < bg_iou_thresh
+ allow_low_quality_matches: if True, produce additional matches
+ for predictions that have only low-quality match candidates.
"""
if fg_iou_thresh < bg_iou_thresh:
raise ValueError(
"Require fg_iou_thresh >= bg_iou_thresh. "
f"Got fg_iou_thresh={fg_iou_thresh}, bg_iou_thresh={bg_iou_thresh}."
)
- self.proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=True)
+ self.proposal_matcher = Matcher(
+ fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=allow_low_quality_matches
+ )
def set_atss_matcher(self, num_candidates: int = 4, center_in_gt: bool = False) -> None:
"""
@@ -422,7 +425,7 @@ def set_box_selector_parameters(
#. For each level, discard boxes with scores less than self.score_thresh.
#. For each level, keep boxes with top self.topk_candidates_per_level scores.
- #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overapping threshold nms_thresh.
+ #. For the whole image, perform non-maximum suppression (NMS) on boxes, with overlapping threshold nms_thresh.
#. For the whole image, keep boxes with top self.detections_per_img scores.
Args:
@@ -614,7 +617,7 @@ def postprocess_detections(
A = self.num_anchors_per_loc.
Return:
- a list of dict, each dict scorresponds to detection result on image.
+ a list of dict, each dict corresponds to detection result on image.
"""
# recover level sizes, HWA or HWDA for each level
@@ -732,7 +735,7 @@ def compute_anchor_matched_idxs(
# or a negative value indicating that anchor i could not be matched.
# BELOW_LOW_THRESHOLD = -1, BETWEEN_THRESHOLDS = -2
if isinstance(self.proposal_matcher, Matcher):
- # if torcvision matcher
+ # if torchvision matcher
match_quality_matrix = self.box_overlap_metric(
targets_per_image[self.target_box_key].to(anchors_per_image.device), anchors_per_image
)
diff --git a/monai/apps/detection/networks/retinanet_network.py b/monai/apps/detection/networks/retinanet_network.py
index 4539a913acc..4a0d8dc2283 100644
--- a/monai/apps/detection/networks/retinanet_network.py
+++ b/monai/apps/detection/networks/retinanet_network.py
@@ -32,7 +32,6 @@
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
-
"""
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
diff --git a/monai/apps/detection/transforms/box_ops.py b/monai/apps/detection/transforms/box_ops.py
index 64f45d1b671..d2445577d03 100644
--- a/monai/apps/detection/transforms/box_ops.py
+++ b/monai/apps/detection/transforms/box_ops.py
@@ -328,7 +328,7 @@ def select_labels(
labels: Union[Sequence[NdarrayOrTensor], NdarrayOrTensor], keep: NdarrayOrTensor
) -> Union[Tuple, NdarrayOrTensor]:
"""
- For element in labels, select indice keep from it.
+ For element in labels, select indices keep from it.
Args:
labels: Sequence of array. Each element represents classification labels or scores
@@ -342,10 +342,10 @@ def select_labels(
labels_select_list = []
keep_t: torch.Tensor = convert_data_type(keep, torch.Tensor)[0]
- for i in range(len(labels_tuple)):
- labels_t: torch.Tensor = convert_data_type(labels_tuple[i], torch.Tensor)[0]
+ for item in labels_tuple:
+ labels_t: torch.Tensor = convert_data_type(item, torch.Tensor)[0]
labels_t = labels_t[keep_t, ...]
- labels_select_list.append(convert_to_dst_type(src=labels_t, dst=labels_tuple[i])[0])
+ labels_select_list.append(convert_to_dst_type(src=labels_t, dst=item)[0])
if isinstance(labels, (torch.Tensor, np.ndarray)):
return labels_select_list[0] # type: ignore
@@ -372,8 +372,9 @@ def swapaxes_boxes(boxes: NdarrayOrTensor, axis1: int, axis2: int):
boxes_swap = boxes.clone()
else:
boxes_swap = deepcopy(boxes) # type: ignore
- boxes_swap[:, [axis1, axis2]] = boxes_swap[:, [axis2, axis1]] # type: ignore
- boxes_swap[:, [spatial_dims + axis1, spatial_dims + axis2]] = boxes_swap[ # type: ignore
+ boxes_swap[:, [axis1, axis2]] = boxes_swap[:, [axis2, axis1]]
+
+ boxes_swap[:, [spatial_dims + axis1, spatial_dims + axis2]] = boxes_swap[
:, [spatial_dims + axis2, spatial_dims + axis1]
]
return boxes_swap
diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py
index 8d4f08d2820..fa365895b5a 100644
--- a/monai/apps/detection/transforms/dictionary.py
+++ b/monai/apps/detection/transforms/dictionary.py
@@ -262,7 +262,8 @@ def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> Tuple[Ndarray
f"'affine' is not found in {meta_key}. \
Please check whether it is the correct the image meta key."
)
- affine: NdarrayOrTensor = meta_dict["affine"] # type: ignore
+ affine: NdarrayOrTensor = meta_dict["affine"]
+
if self.affine_lps_to_ras: # RAS affine
affine = orientation_ras_lps(affine)
@@ -522,14 +523,14 @@ def set_random_state(
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
return d
self.randomize(None)
# all the keys share the same random zoom factor
- self.rand_zoom.randomize(d[first_key]) # type: ignore
+ self.rand_zoom.randomize(d[first_key])
# zoom box
for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):
@@ -1108,7 +1109,7 @@ def randomize( # type: ignore
thresh_image: Optional[NdarrayOrTensor] = None,
) -> None:
if fg_indices is None or bg_indices is None:
- # We don't require crop center to be whthin the boxes.
+ # We don't require crop center to be within the boxes.
# As along as the cropped patch contains a box, it is considered as a foreground patch.
# Positions within extended_boxes are crop centers for foreground patches
extended_boxes_np = self.generate_fg_center_boxes_np(boxes, image_size)
diff --git a/monai/apps/detection/utils/ATSS_matcher.py b/monai/apps/detection/utils/ATSS_matcher.py
index f0170422bb9..c208fcd41c0 100644
--- a/monai/apps/detection/utils/ATSS_matcher.py
+++ b/monai/apps/detection/utils/ATSS_matcher.py
@@ -59,7 +59,6 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""
The functions in this script are adapted from nnDetection,
https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/matcher.py
@@ -74,7 +73,7 @@
"""
import logging
-from abc import ABC
+from abc import ABC, abstractmethod
from typing import Callable, Sequence, Tuple, TypeVar
import torch
@@ -139,6 +138,7 @@ def __call__(
num_anchors_per_loc=num_anchors_per_loc,
)
+ @abstractmethod
def compute_matches(
self, boxes: torch.Tensor, anchors: torch.Tensor, num_anchors_per_level: Sequence[int], num_anchors_per_loc: int
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -281,7 +281,6 @@ def compute_matches(
matched_vals, matches = ious_inf.to(COMPUTE_DTYPE).max(dim=0)
matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
- # print(f"Num matches {(matches >= 0).sum()}, Adapt IoU {iou_thresh_per_gt}")
return match_quality_matrix, matches
diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py
index c028228d95e..55c256248a1 100644
--- a/monai/apps/detection/utils/anchor_utils.py
+++ b/monai/apps/detection/utils/anchor_utils.py
@@ -32,7 +32,6 @@
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
-
"""
This script is adapted from
https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py
@@ -171,13 +170,13 @@ def generate_anchors(
scales_t = torch.as_tensor(scales, dtype=dtype, device=device) # sized (N,)
aspect_ratios_t = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) # sized (M,) or (M,2)
if (self.spatial_dims >= 3) and (len(aspect_ratios_t.shape) != 2):
- ValueError(
+ raise ValueError(
f"In {self.spatial_dims}-D image, aspect_ratios for each level should be \
{len(aspect_ratios_t.shape)-1}-D. But got aspect_ratios with shape {aspect_ratios_t.shape}."
)
if (self.spatial_dims >= 3) and (aspect_ratios_t.shape[1] != self.spatial_dims - 1):
- ValueError(
+ raise ValueError(
f"In {self.spatial_dims}-D image, aspect_ratios for each level should has \
shape (_,{self.spatial_dims-1}). But got aspect_ratios with shape {aspect_ratios_t.shape}."
)
diff --git a/monai/apps/detection/utils/box_coder.py b/monai/apps/detection/utils/box_coder.py
index 036e23f2423..6458360fcde 100644
--- a/monai/apps/detection/utils/box_coder.py
+++ b/monai/apps/detection/utils/box_coder.py
@@ -43,7 +43,6 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""
This script is modified from torchvision to support N-D images,
@@ -173,7 +172,7 @@ def decode(self, rel_codes: Tensor, reference_boxes: Sequence[Tensor]) -> Tensor
Args:
rel_codes: encoded boxes, Nx4 or Nx6 torch tensor.
- boxes: a list of reference boxes, each element is Mx4 or Mx6 torch tensor.
+ reference_boxes: a list of reference boxes, each element is Mx4 or Mx6 torch tensor.
The box mode is assumed to be ``StandardMode``
Return:
diff --git a/monai/apps/detection/utils/box_selector.py b/monai/apps/detection/utils/box_selector.py
index de4de85ea0d..e0e82dbef7c 100644
--- a/monai/apps/detection/utils/box_selector.py
+++ b/monai/apps/detection/utils/box_selector.py
@@ -32,7 +32,6 @@
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
-
"""
Part of this script is adapted from
https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py
@@ -105,7 +104,7 @@ def select_top_score_idx_per_level(self, logits: Tensor) -> Tuple[Tensor, Tensor
"""
Select indices with highest scores.
- The indice selection is performed with the following steps:
+ The indices selection is performed with the following steps:
#. If self.apply_sigmoid, get scores by applying sigmoid to logits. Otherwise, use logits as scores.
#. Discard indices with scores less than self.score_thresh
diff --git a/monai/apps/detection/utils/hard_negative_sampler.py b/monai/apps/detection/utils/hard_negative_sampler.py
index 3067a41a0e6..ee423bb4bcb 100644
--- a/monai/apps/detection/utils/hard_negative_sampler.py
+++ b/monai/apps/detection/utils/hard_negative_sampler.py
@@ -24,21 +24,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
The functions in this script are adapted from nnDetection,
https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/sampler.py
"""
import logging
-from abc import ABC
from typing import List, Tuple
import torch
from torch import Tensor
-class HardNegativeSamplerBase(ABC):
+class HardNegativeSamplerBase:
"""
Base class of hard negative sampler.
@@ -70,11 +68,11 @@ def select_negatives(self, negative: Tensor, num_neg: int, fg_probs: Tensor) ->
where P is the number of negative samples
num_neg: number of negative samples to sample
fg_probs: maximum foreground prediction scores (probability) across all the classes
- for each sample, sized (A,), where A is the the number of samples.
+ for each sample, sized (A,), where A is the number of samples.
Returns:
binary mask of negative samples to choose, sized (A,),
- where A is the the number of samples in one image
+ where A is the number of samples in one image
"""
if negative.numel() > fg_probs.numel():
raise ValueError("The number of negative samples should not be larger than the number of all samples.")
diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py
index f389f7ad33e..6e1770b19ed 100644
--- a/monai/apps/mmars/mmars.py
+++ b/monai/apps/mmars/mmars.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
Utilities for accessing Nvidia MMARs
@@ -287,7 +286,7 @@ def load_from_mmar(
model_inst = model_cls()
if pretrained:
_, changed, unchanged = copy_model_state(model_inst, model_dict.get(model_key, model_dict), inplace=True)
- if not (changed and not unchanged): # not all model_inst varaibles are changed
+ if not (changed and not unchanged): # not all model_inst variables are changed
logger.warning(f"*** Loading model state -- unchanged: {len(unchanged)}, changed: {len(changed)}.")
logger.info("\n---")
doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(item[Keys.NAME], model_prefix="nvidia:med:")
diff --git a/monai/apps/mmars/model_desc.py b/monai/apps/mmars/model_desc.py
index 47bbfe2eaa2..e0a7f26117d 100644
--- a/monai/apps/mmars/model_desc.py
+++ b/monai/apps/mmars/model_desc.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
Collection of the remote MMAR descriptors
diff --git a/monai/apps/nuclick/transforms.py b/monai/apps/nuclick/transforms.py
index 6a6b31ba432..d8453ee4f3d 100644
--- a/monai/apps/nuclick/transforms.py
+++ b/monai/apps/nuclick/transforms.py
@@ -145,7 +145,6 @@ def __init__(
min_area: int = 5,
):
- # self.label = label
super().__init__(keys, allow_missing_keys=False)
self.others = others
self.mask_value = mask_value
@@ -431,8 +430,7 @@ def get_patches_and_signals(self, img, click_map, bounding_boxes, cx, cy, m, n,
cx = np.delete(cx, del_indices)
cy = np.delete(cy, del_indices)
- for i in range(len(bounding_boxes)):
- bounding_box = bounding_boxes[i]
+ for i, bounding_box in enumerate(bounding_boxes):
x_start = bounding_box[0]
y_start = bounding_box[1]
x_end = bounding_box[2]
@@ -527,9 +525,9 @@ def post_processing(self, preds, thresh=0.33, min_size=10, min_hole=30, do_recon
def gen_instance_map(self, masks, bounding_boxes, m, n, flatten=True):
instance_map = np.zeros((m, n), dtype=np.uint16)
- for i in range(len(masks)):
+ for i, item in enumerate(masks):
this_bb = bounding_boxes[i]
- this_mask_pos = np.argwhere(masks[i] > 0)
+ this_mask_pos = np.argwhere(item > 0)
this_mask_pos[:, 0] = this_mask_pos[:, 0] + this_bb[1]
this_mask_pos[:, 1] = this_mask_pos[:, 1] + this_bb[0]
instance_map[this_mask_pos[:, 0], this_mask_pos[:, 1]] = 1 if flatten else i + 1
diff --git a/monai/apps/pathology/__init__.py b/monai/apps/pathology/__init__.py
index 81742caf656..da5cec7e6cb 100644
--- a/monai/apps/pathology/__init__.py
+++ b/monai/apps/pathology/__init__.py
@@ -11,6 +11,7 @@
from .data import MaskedInferenceWSIDataset, PatchWSIDataset, SmartCachePatchWSIDataset
from .handlers import ProbMapProducer
+from .losses import HoVerNetLoss
from .metrics import LesionFROC
from .transforms.stain.array import ExtractHEStains, NormalizeHEStains
from .transforms.stain.dictionary import (
diff --git a/monai/apps/pathology/data/datasets.py b/monai/apps/pathology/data/datasets.py
index 71f3214ea4a..1d90615d1ea 100644
--- a/monai/apps/pathology/data/datasets.py
+++ b/monai/apps/pathology/data/datasets.py
@@ -17,11 +17,12 @@
from monai.data import Dataset, SmartCacheDataset
from monai.data.image_reader import WSIReader
-from monai.utils import ensure_tuple_rep
+from monai.utils import deprecated, ensure_tuple_rep
__all__ = ["PatchWSIDataset", "SmartCachePatchWSIDataset", "MaskedInferenceWSIDataset"]
+@deprecated(since="0.8", msg_suffix="use `monai.data.PatchWSIDataset` instead.")
class PatchWSIDataset(Dataset):
"""
This dataset reads whole slide images, extracts regions, and creates patches.
@@ -44,7 +45,7 @@ class PatchWSIDataset(Dataset):
This means from "image1.tiff" extract a region centered at the given location `location`
with the size of `region_size`, and then extract patches with the size of `patch_size`
from a grid with the shape of `grid_shape`.
- Be aware the the `grid_shape` should construct a grid with the same number of element as `labels`,
+ Be aware the `grid_shape` should construct a grid with the same number of element as `labels`,
so for this example the `grid_shape` should be (2, 2).
"""
@@ -103,6 +104,7 @@ def __getitem__(self, index):
return patches
+@deprecated(since="0.8", msg_suffix="use `monai.data.SmartCacheDataset` with `monai.data.PatchWSIDataset` instead.")
class SmartCachePatchWSIDataset(SmartCacheDataset):
"""Add SmartCache functionality to `PatchWSIDataset`.
@@ -177,6 +179,7 @@ def __init__(
)
+@deprecated(since="0.8", msg_suffix="use `monai.data.MaskedPatchWSIDataset` instead.")
class MaskedInferenceWSIDataset(Dataset):
"""
This dataset load the provided foreground masks at an arbitrary resolution level,
diff --git a/monai/apps/pathology/engines/__init__.py b/monai/apps/pathology/engines/__init__.py
new file mode 100644
index 00000000000..68c084d40d7
--- /dev/null
+++ b/monai/apps/pathology/engines/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .utils import PrepareBatchHoVerNet
diff --git a/monai/apps/pathology/engines/utils.py b/monai/apps/pathology/engines/utils.py
new file mode 100644
index 00000000000..3a190a146bf
--- /dev/null
+++ b/monai/apps/pathology/engines/utils.py
@@ -0,0 +1,54 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional, Sequence, Union
+
+import torch
+
+from monai.engines import PrepareBatch, PrepareBatchExtraInput
+from monai.utils import ensure_tuple
+from monai.utils.enums import HoVerNetBranch
+
+__all__ = ["PrepareBatchHoVerNet"]
+
+
+class PrepareBatchHoVerNet(PrepareBatch):
+ """
+ Customized prepare batch callable for trainers or evaluators which support label to be a dictionary.
+ Extra items are specified by the `extra_keys` parameter and are extracted from the input dictionary (ie. the batch).
+ This assumes label is a dictionary.
+
+ Args:
+ extra_keys: If a sequence of strings is provided, values from the input dictionary are extracted from
+ those keys and passed to the nework as extra positional arguments.
+ """
+
+ def __init__(self, extra_keys: Sequence[str]) -> None:
+ if len(ensure_tuple(extra_keys)) != 2:
+ raise ValueError(f"length of `extra_keys` should be 2, get {len(ensure_tuple(extra_keys))}")
+ self.prepare_batch = PrepareBatchExtraInput(extra_keys)
+
+ def __call__(
+ self,
+ batchdata: Dict[str, torch.Tensor],
+ device: Optional[Union[str, torch.device]] = None,
+ non_blocking: bool = False,
+ **kwargs,
+ ):
+ """
+ Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
+ https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
+ `kwargs` supports other args for `Tensor.to()` API.
+ """
+ image, _label, extra_label, _ = self.prepare_batch(batchdata, device, non_blocking, **kwargs)
+ label = {HoVerNetBranch.NP: _label, HoVerNetBranch.NC: extra_label[0], HoVerNetBranch.HV: extra_label[1]}
+
+ return image, label
diff --git a/monai/apps/pathology/handlers/prob_map_producer.py b/monai/apps/pathology/handlers/prob_map_producer.py
index 62507dc0cb4..d5b1b50c475 100644
--- a/monai/apps/pathology/handlers/prob_map_producer.py
+++ b/monai/apps/pathology/handlers/prob_map_producer.py
@@ -27,7 +27,10 @@
@deprecated(
since="0.8",
- msg_suffix="use `monai.handler.ProbMapProducer` (with `monai.data.wsi_dataset.SlidingPatchWSIDataset`) instead.",
+ msg_suffix=(
+ "use `monai.handler.ProbMapProducer` (with `monai.data.wsi_dataset.MaskedPatchWSIDataset` or "
+ "`monai.data.wsi_dataset.SlidingPatchWSIDataset`) instead."
+ ),
)
class ProbMapProducer:
"""
diff --git a/monai/apps/pathology/losses/__init__.py b/monai/apps/pathology/losses/__init__.py
new file mode 100644
index 00000000000..5e960b34cf6
--- /dev/null
+++ b/monai/apps/pathology/losses/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .hovernet_loss import HoVerNetLoss
diff --git a/monai/apps/pathology/losses/hovernet_loss.py b/monai/apps/pathology/losses/hovernet_loss.py
new file mode 100644
index 00000000000..1133c21fb6e
--- /dev/null
+++ b/monai/apps/pathology/losses/hovernet_loss.py
@@ -0,0 +1,159 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+import torch
+from torch.nn import CrossEntropyLoss
+from torch.nn import functional as F
+from torch.nn.modules.loss import _Loss
+
+from monai.losses import DiceLoss
+from monai.transforms import SobelGradients
+from monai.utils.enums import HoVerNetBranch
+
+
+class HoVerNetLoss(_Loss):
+ """
+ Loss function for HoVerNet pipeline, which is combination of losses across the three branches.
+ The NP (nucleus prediction) branch uses Dice + CrossEntropy.
+ The HV (Horizontal and Vertical) distance from centroid branch uses MSE + MSE of the gradient.
+ The NC (Nuclear Class prediction) branch uses Dice + CrossEntropy
+ The result is a weighted sum of these losses.
+
+ Args:
+ lambda_hv_mse: Weight factor to apply to the HV regression MSE part of the overall loss
+ lambda_hv_mse_grad: Weight factor to apply to the MSE of the HV gradient part of the overall loss
+ lambda_np_ce: Weight factor to apply to the nuclei prediction CrossEntropyLoss part
+ of the overall loss
+ lambda_np_dice: Weight factor to apply to the nuclei prediction DiceLoss part of overall loss
+ lambda_nc_ce: Weight factor to apply to the nuclei class prediction CrossEntropyLoss part
+ of the overall loss
+ lambda_nc_dice: Weight factor to apply to the nuclei class prediction DiceLoss part of the
+ overall loss
+
+ """
+
+ def __init__(
+ self,
+ lambda_hv_mse: float = 2.0,
+ lambda_hv_mse_grad: float = 1.0,
+ lambda_np_ce: float = 1.0,
+ lambda_np_dice: float = 1.0,
+ lambda_nc_ce: float = 1.0,
+ lambda_nc_dice: float = 1.0,
+ ) -> None:
+ self.lambda_hv_mse = lambda_hv_mse
+ self.lambda_hv_mse_grad = lambda_hv_mse_grad
+ self.lambda_np_ce = lambda_np_ce
+ self.lambda_np_dice = lambda_np_dice
+ self.lambda_nc_ce = lambda_nc_ce
+ self.lambda_nc_dice = lambda_nc_dice
+ super().__init__()
+
+ self.dice = DiceLoss(softmax=True, smooth_dr=1e-03, smooth_nr=1e-03, reduction="sum", batch=True)
+ self.ce = CrossEntropyLoss(reduction="mean")
+ self.sobel = SobelGradients(kernel_size=5)
+
+ def _compute_sobel(self, image: torch.Tensor) -> torch.Tensor:
+
+ batch_size = image.shape[0]
+ result_h = self.sobel(torch.squeeze(image[:, 0], dim=1))[batch_size:]
+ result_v = self.sobel(torch.squeeze(image[:, 1], dim=1))[:batch_size]
+
+ return torch.cat([result_h[:, None, ...], result_v[:, None, ...]], dim=1)
+
+ def _mse_gradient_loss(self, prediction: torch.Tensor, target: torch.Tensor, focus: torch.Tensor) -> torch.Tensor:
+ """Compute the MSE loss of the gradients of the horizontal and vertical centroid distance maps"""
+
+ pred_grad = self._compute_sobel(prediction)
+ true_grad = self._compute_sobel(target)
+
+ loss = pred_grad - true_grad
+
+ # The focus constrains the loss computation to the detected nuclear regions
+ # (i.e. background is excluded)
+ focus = focus[:, None, ...]
+ focus = torch.cat((focus, focus), 1)
+
+ loss = focus * (loss * loss)
+ loss = loss.sum() / (focus.sum() + 1.0e-8)
+
+ return loss
+
+ def forward(self, prediction: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor]) -> torch.Tensor:
+ """
+ Args:
+ prediction: dictionary of predicted outputs for three branches,
+ each of which should have the shape of BNHW.
+ target: dictionary of ground truths for three branches,
+ each of which should have the shape of BNHW.
+ """
+
+ if not (HoVerNetBranch.NP.value in prediction and HoVerNetBranch.HV.value in prediction):
+ raise ValueError(
+ "nucleus prediction (NP) and horizontal_vertical (HV) branches must be "
+ "present for prediction and target parameters"
+ )
+ if not (HoVerNetBranch.NP.value in target and HoVerNetBranch.HV.value in target):
+ raise ValueError(
+ "nucleus prediction (NP) and horizontal_vertical (HV) branches must be "
+ "present for prediction and target parameters"
+ )
+ if HoVerNetBranch.NC.value not in target and HoVerNetBranch.NC.value in target:
+ raise ValueError(
+ "type_prediction (NC) must be present in both or neither of the prediction and target parameters"
+ )
+ if HoVerNetBranch.NC.value in target and HoVerNetBranch.NC.value not in target:
+ raise ValueError(
+ "type_prediction (NC) must be present in both or neither of the prediction and target parameters"
+ )
+
+ # Compute the NP branch loss
+ dice_loss_np = (
+ self.dice(prediction[HoVerNetBranch.NP.value], target[HoVerNetBranch.NP.value]) * self.lambda_np_dice
+ )
+ # convert to target class indices
+ argmax_target = target[HoVerNetBranch.NP.value].argmax(dim=1)
+ ce_loss_np = self.ce(prediction[HoVerNetBranch.NP.value], argmax_target) * self.lambda_np_ce
+ loss_np = dice_loss_np + ce_loss_np
+
+ # Compute the HV branch loss
+ loss_hv_mse = (
+ F.mse_loss(prediction[HoVerNetBranch.HV.value], target[HoVerNetBranch.HV.value]) * self.lambda_hv_mse
+ )
+
+ # Use the nuclei class, one hot encoded, as the mask
+ loss_hv_mse_grad = (
+ self._mse_gradient_loss(
+ prediction[HoVerNetBranch.HV.value],
+ target[HoVerNetBranch.HV.value],
+ target[HoVerNetBranch.NP.value][:, 1],
+ )
+ * self.lambda_hv_mse_grad
+ )
+ loss_hv = loss_hv_mse_grad + loss_hv_mse
+
+ # Compute the NC branch loss
+ loss_nc = 0
+ if HoVerNetBranch.NC.value in prediction:
+ dice_loss_nc = (
+ self.dice(prediction[HoVerNetBranch.NC.value], target[HoVerNetBranch.NC.value]) * self.lambda_nc_dice
+ )
+ # Convert to target class indices
+ argmax_target = target[HoVerNetBranch.NC.value].argmax(dim=1)
+ ce_loss_nc = self.ce(prediction[HoVerNetBranch.NC.value], argmax_target) * self.lambda_nc_ce
+ loss_nc = dice_loss_nc + ce_loss_nc
+
+ # Sum the losses from each branch
+ loss: torch.Tensor = loss_hv + loss_np + loss_nc
+
+ return loss
diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py
index 6073bd0cda3..6c7965bae6e 100644
--- a/monai/apps/pathology/metrics/lesion_froc.py
+++ b/monai/apps/pathology/metrics/lesion_froc.py
@@ -52,7 +52,7 @@ class LesionFROC:
Defaults to (0.25, 0.5, 1, 2, 4, 8) which is the same as the CAMELYON 16 Challenge.
nms_sigma: the standard deviation for gaussian filter of non-maximal suppression. Defaults to 0.0.
nms_prob_threshold: the probability threshold of non-maximal suppression. Defaults to 0.5.
- nms_box_size: the box size (in pixel) to be removed around the the pixel for non-maximal suppression.
+ nms_box_size: the box size (in pixel) to be removed around the pixel for non-maximal suppression.
image_reader_name: the name of library to be used for loading whole slide imaging, either CuCIM or OpenSlide.
Defaults to CuCIM.
diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py
index 290c0ba6a84..3e784b8ebf0 100644
--- a/monai/apps/pathology/transforms/__init__.py
+++ b/monai/apps/pathology/transforms/__init__.py
@@ -9,6 +9,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from .post.array import (
+ GenerateDistanceMap,
+ GenerateInstanceBorder,
+ GenerateInstanceCentroid,
+ GenerateInstanceContour,
+ GenerateInstanceType,
+ GenerateSuccinctContour,
+ GenerateWatershedMarkers,
+ GenerateWatershedMask,
+ Watershed,
+)
+from .post.dictionary import (
+ GenerateDistanceMapD,
+ GenerateDistanceMapd,
+ GenerateDistanceMapDict,
+ GenerateInstanceBorderD,
+ GenerateInstanceBorderd,
+ GenerateInstanceBorderDict,
+ GenerateInstanceCentroidD,
+ GenerateInstanceCentroidd,
+ GenerateInstanceCentroidDict,
+ GenerateInstanceContourD,
+ GenerateInstanceContourd,
+ GenerateInstanceContourDict,
+ GenerateInstanceTypeD,
+ GenerateInstanceTyped,
+ GenerateInstanceTypeDict,
+ GenerateSuccinctContourD,
+ GenerateSuccinctContourd,
+ GenerateSuccinctContourDict,
+ GenerateWatershedMarkersD,
+ GenerateWatershedMarkersd,
+ GenerateWatershedMarkersDict,
+ GenerateWatershedMaskD,
+ GenerateWatershedMaskd,
+ GenerateWatershedMaskDict,
+ WatershedD,
+ Watershedd,
+ WatershedDict,
+)
from .spatial.array import SplitOnGrid, TileOnGrid
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict
from .stain.array import ExtractHEStains, NormalizeHEStains
diff --git a/monai/apps/pathology/transforms/post/__init__.py b/monai/apps/pathology/transforms/post/__init__.py
new file mode 100644
index 00000000000..3e6af77ce6b
--- /dev/null
+++ b/monai/apps/pathology/transforms/post/__init__.py
@@ -0,0 +1,51 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .array import (
+ GenerateDistanceMap,
+ GenerateInstanceBorder,
+ GenerateInstanceCentroid,
+ GenerateInstanceContour,
+ GenerateInstanceType,
+ GenerateSuccinctContour,
+ GenerateWatershedMarkers,
+ GenerateWatershedMask,
+ Watershed,
+)
+from .dictionary import (
+ GenerateDistanceMapD,
+ GenerateDistanceMapd,
+ GenerateDistanceMapDict,
+ GenerateInstanceBorderD,
+ GenerateInstanceBorderd,
+ GenerateInstanceBorderDict,
+ GenerateInstanceCentroidD,
+ GenerateInstanceCentroidd,
+ GenerateInstanceCentroidDict,
+ GenerateInstanceContourD,
+ GenerateInstanceContourd,
+ GenerateInstanceContourDict,
+ GenerateInstanceTypeD,
+ GenerateInstanceTyped,
+ GenerateInstanceTypeDict,
+ GenerateSuccinctContourD,
+ GenerateSuccinctContourd,
+ GenerateSuccinctContourDict,
+ GenerateWatershedMarkersD,
+ GenerateWatershedMarkersd,
+ GenerateWatershedMarkersDict,
+ GenerateWatershedMaskD,
+ GenerateWatershedMaskd,
+ GenerateWatershedMaskDict,
+ WatershedD,
+ Watershedd,
+ WatershedDict,
+)
diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py
new file mode 100644
index 00000000000..55ff5311724
--- /dev/null
+++ b/monai/apps/pathology/transforms/post/array.py
@@ -0,0 +1,626 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+from monai.config.type_definitions import DtypeLike, NdarrayOrTensor
+from monai.transforms.post.array import Activations, AsDiscrete, RemoveSmallObjects, SobelGradients
+from monai.transforms.transform import Transform
+from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique
+from monai.utils import TransformBackends, convert_to_numpy, optional_import
+from monai.utils.misc import ensure_tuple_rep
+from monai.utils.type_conversion import convert_to_dst_type
+
+label, _ = optional_import("scipy.ndimage.measurements", name="label")
+disk, _ = optional_import("skimage.morphology", name="disk")
+opening, _ = optional_import("skimage.morphology", name="opening")
+watershed, _ = optional_import("skimage.segmentation", name="watershed")
+find_contours, _ = optional_import("skimage.measure", name="find_contours")
+centroid, _ = optional_import("skimage.measure", name="centroid")
+
+__all__ = [
+ "Watershed",
+ "GenerateWatershedMask",
+ "GenerateInstanceBorder",
+ "GenerateDistanceMap",
+ "GenerateWatershedMarkers",
+ "GenerateSuccinctContour",
+ "GenerateInstanceContour",
+ "GenerateInstanceCentroid",
+ "GenerateInstanceType",
+]
+
+
+class Watershed(Transform):
+ """
+ Use `skimage.segmentation.watershed` to get instance segmentation results from images.
+ See: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.watershed.
+
+ Args:
+ connectivity: An array with the same number of dimensions as image whose non-zero elements indicate
+ neighbors for connection. Following the scipy convention, default is a one-connected array of
+ the dimension of the image.
+ dtype: target data content type to convert, default is np.uint8.
+
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(self, connectivity: Optional[int] = 1, dtype: DtypeLike = np.uint8) -> None:
+ self.connectivity = connectivity
+ self.dtype = dtype
+
+ def __call__(
+ self, image: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None, markers: Optional[NdarrayOrTensor] = None
+ ) -> NdarrayOrTensor:
+ """
+ Args:
+ image: image where the lowest value points are labeled first. Shape must be [1, H, W, [D]].
+ mask: optional, the same shape as image. Only points at which mask == True will be labeled.
+ If None (no mask given), it is a volume of all 1s.
+ markers: optional, the same shape as image. The desired number of markers, or an array marking
+ the basins with the values to be assigned in the label matrix. Zero means not a marker.
+ If None (no markers given), the local minima of the image are used as markers.
+ """
+
+ image = convert_to_numpy(image)
+ markers = convert_to_numpy(markers)
+ mask = convert_to_numpy(mask)
+
+ instance_seg = watershed(image, markers=markers, mask=mask, connectivity=self.connectivity)
+
+ return convert_to_dst_type(instance_seg, image, dtype=self.dtype)[0]
+
+
+class GenerateWatershedMask(Transform):
+ """
+ generate mask used in `watershed`. Only points at which mask == True will be labeled.
+
+ Args:
+ softmax: if True, apply a softmax function to the prediction.
+ sigmoid: if True, apply a sigmoid function to the prediction.
+ threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold.
+ remove_small_objects: whether need to remove some objects in the marker. Defaults to True.
+ min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10.
+ dtype: target data content type to convert, default is np.uint8.
+
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(
+ self,
+ softmax: bool = True,
+ sigmoid: bool = False,
+ threshold: Optional[float] = None,
+ remove_small_objects: bool = True,
+ min_size: int = 10,
+ dtype: DtypeLike = np.uint8,
+ ) -> None:
+ if sigmoid and threshold is None:
+ raise ValueError("Threshold is needed when using sigmoid activation.")
+
+ self.dtype = dtype
+ self.activations = Activations(sigmoid=sigmoid, softmax=softmax)
+ self.asdiscrete = AsDiscrete(threshold=threshold, argmax=softmax)
+ if remove_small_objects:
+ self.remove_small_objects = RemoveSmallObjects(min_size=min_size)
+ else:
+ self.remove_small_objects = None # type: ignore
+
+ def __call__(self, prob_map: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ prob_map: probability map of segmentation, shape must be [C, H, W, [D]]
+ """
+
+ pred = self.activations(prob_map)
+ pred = self.asdiscrete(pred)
+
+ pred = convert_to_numpy(pred)
+
+ pred = label(pred)[0]
+ if self.remove_small_objects:
+ pred = self.remove_small_objects(pred)
+ pred[pred > 0] = 1 # type: ignore
+
+ return convert_to_dst_type(pred, prob_map, dtype=self.dtype)[0]
+
+
+class GenerateInstanceBorder(Transform):
+ """
+ Generate instance border by hover map. The more parts of the image that cannot be identified as foreground areas,
+ the larger the grey scale value. The grey value of the instance's border will be larger.
+
+ Args:
+ kernel_size: the size of the Sobel kernel. Defaults to 21.
+ min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10.
+ remove_small_objects: whether need to remove some objects in segmentation results. Defaults to True.
+ dtype: target data content type to convert, default is np.float32.
+
+
+ Raises:
+ ValueError: when the `mask` shape is not [1, H, W].
+ ValueError: when the `hover_map` shape is not [2, H, W].
+
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(
+ self,
+ kernel_size: int = 21,
+ min_size: int = 10,
+ remove_small_objects: bool = True,
+ dtype: DtypeLike = np.float32,
+ ) -> None:
+
+ self.dtype = dtype
+
+ self.sobel_gradient = SobelGradients(kernel_size=kernel_size)
+ if remove_small_objects:
+ self.remove_small_objects = RemoveSmallObjects(min_size=min_size)
+ else:
+ self.remove_small_objects = None # type: ignore
+
+ def __call__(self, mask: NdarrayOrTensor, hover_map: NdarrayOrTensor) -> NdarrayOrTensor: # type: ignore
+ """
+ Args:
+ mask: binarized segmentation result. Shape must be [1, H, W].
+ hover_map: horizontal and vertical distances of nuclear pixels to their centres of mass. Shape must be [2, H, W].
+ The first and second channel represent the horizontal and vertical maps respectively. For more details refer
+ to papers: https://arxiv.org/abs/1812.06499.
+
+ Return:
+ Instance border map.
+
+ Raises:
+ ValueError: when the `hover_map` has only one value.
+ ValueError: when the `sobel gradient map` has only one value.
+
+ """
+ if len(mask.shape) != 3 or len(hover_map.shape) != 3:
+ raise ValueError(
+ f"Suppose the mask and hover map should be with shape of [C, H, W], but got {mask.shape}, {hover_map.shape}"
+ )
+ if mask.shape[0] != 1:
+ raise ValueError(f"Suppose the mask only has one channel, but got {mask.shape[0]}")
+ if hover_map.shape[0] != 2:
+ raise ValueError(f"Suppose the hover map only has two channels, but got {hover_map.shape[0]}")
+
+ hover_h = hover_map[0:1, ...]
+ hover_v = hover_map[1:2, ...]
+
+ hover_h_min, hover_h_max = min(hover_h), max(hover_h)
+ hover_v_min, hover_v_max = min(hover_v), max(hover_v)
+ if (hover_h_max - hover_h_min) == 0 or (hover_v_max - hover_v_min) == 0:
+ raise ValueError("Not a valid hover map, please check your input")
+ hover_h = (hover_h - hover_h_min) / (hover_h_max - hover_h_min)
+ hover_v = (hover_v - hover_v_min) / (hover_v_max - hover_v_min)
+ sobelh = self.sobel_gradient(hover_h)[0, ...]
+ sobelv = self.sobel_gradient(hover_v)[1, ...]
+ sobelh_min, sobelh_max = min(sobelh), max(sobelh)
+ sobelv_min, sobelv_max = min(sobelv), max(sobelv)
+ if (sobelh_max - sobelh_min) == 0 or (sobelv_max - sobelv_min) == 0:
+ raise ValueError("Not a valid sobel gradient map")
+ sobelh = 1 - (sobelh - sobelh_min) / (sobelh_max - sobelh_min)
+ sobelv = 1 - (sobelv - sobelv_min) / (sobelv_max - sobelv_min)
+
+ # combine the h & v values using max
+ overall = maximum(sobelh, sobelv)
+ overall = overall - (1 - mask)
+ overall[overall < 0] = 0
+
+ return convert_to_dst_type(overall, mask, dtype=self.dtype)[0]
+
+
+class GenerateDistanceMap(Transform):
+ """
+ Generate distance map.
+ In general, the instance map is calculated from the distance to the background.
+ Here, we use 1 - "instance border map" to generate the distance map.
+ Nuclei values form mountains so inverse to get basins.
+
+ Args:
+ smooth_fn: execute smooth function on distance map. Defaults to None. You can specify
+ callable functions for smoothing.
+ For example, if you want apply gaussian smooth, you can specify `smooth_fn = GaussianSmooth()`
+ dtype: target data content type to convert, default is np.float32.
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(self, smooth_fn: Optional[Callable] = None, dtype: DtypeLike = np.float32) -> None:
+ self.smooth_fn = smooth_fn
+ self.dtype = dtype
+
+ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> NdarrayOrTensor: # type: ignore
+ """
+ Args:
+ mask: binarized segmentation result. Shape must be [1, H, W].
+ instance_border: foreground probability map. Shape must be [1, H, W].
+ """
+ if mask.shape[0] != 1 or mask.ndim != 3:
+ raise ValueError(f"Input mask should be with size of [1, H, W], but got {mask.shape}")
+ if instance_border.shape[0] != 1 or instance_border.ndim != 3:
+ raise ValueError(f"Input instance_border should be with size of [1, H, W], but got {instance_border.shape}")
+
+ distance_map = (1.0 - instance_border) * mask
+
+ if callable(self.smooth_fn):
+ distance_map = self.smooth_fn(distance_map)
+
+ return convert_to_dst_type(-distance_map, mask, dtype=self.dtype)[0]
+
+
+class GenerateWatershedMarkers(Transform):
+ """
+ Generate markers to be used in `watershed`. The watershed algorithm treats pixels values as a local topography
+ (elevation). The algorithm floods basins from the markers until basins attributed to different markers meet on
+ watershed lines. Generally, markers are chosen as local minima of the image, from which basins are flooded.
+ Here is the implementation from HoVerNet papar.
+ For more details refer to papers: https://arxiv.org/abs/1812.06499.
+
+ Args:
+ threshold: threshold the float values of foreground probability map to int 0 or 1 with specified theashold.
+ It turns uncertain area to 1 and other area to 0. Defaults to 0.4.
+ radius: the radius of the disk-shaped footprint used in `opening`. Defaults to 2.
+ min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10.
+ remove_small_objects: whether need to remove some objects in the marker. Defaults to True.
+ postprocess_fn: execute additional post transformation on marker. Defaults to None.
+ dtype: target data content type to convert, default is np.uint8.
+
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(
+ self,
+ threshold: float = 0.4,
+ radius: int = 2,
+ min_size: int = 10,
+ remove_small_objects: bool = True,
+ postprocess_fn: Optional[Callable] = None,
+ dtype: DtypeLike = np.uint8,
+ ) -> None:
+ self.threshold = threshold
+ self.radius = radius
+ self.postprocess_fn = postprocess_fn
+ self.dtype = dtype
+
+ if remove_small_objects:
+ self.remove_small_objects = RemoveSmallObjects(min_size=min_size)
+
+ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> NdarrayOrTensor: # type: ignore
+ """
+ Args:
+ mask: binarized segmentation result. Shape must be [1, H, W].
+ instance_border: instance border map. Shape must be [1, H, W].
+ """
+ if mask.shape[0] != 1 or mask.ndim != 3:
+ raise ValueError(f"Input mask should be with size of [1, H, W], but got {mask.shape}")
+ if instance_border.shape[0] != 1 or instance_border.ndim != 3:
+ raise ValueError(f"Input instance_border should be with size of [1, H, W], but got {instance_border.shape}")
+
+ instance_border = instance_border >= self.threshold # uncertain area
+
+ marker = mask - convert_to_dst_type(instance_border, mask, np.uint8)[0] # certain foreground
+ marker[marker < 0] = 0 # type: ignore
+ if self.postprocess_fn:
+ marker = self.postprocess_fn(marker)
+
+ marker = convert_to_numpy(marker)
+
+ marker = opening(marker.squeeze(), disk(self.radius))
+ marker = label(marker)[0]
+ if self.remove_small_objects:
+ marker = self.remove_small_objects(marker[None])
+
+ return convert_to_dst_type(marker, mask, dtype=self.dtype)[0]
+
+
+class GenerateSuccinctContour(Transform):
+ """
+ Converts Scipy-style contours(generated by skimage.measure.find_contours) to a more succinct version which only includes
+ the pixels to which lines need to be drawn (i.e. not the intervening pixels along each line).
+
+ Args:
+ height: height of bounding box, used to detect direction of line segment.
+ width: width of bounding box, used to detect direction of line segment.
+
+ Returns:
+ the pixels that need to be joined by straight lines to describe the outmost pixels of the foreground similar to
+ OpenCV's cv.CHAIN_APPROX_SIMPLE (counterclockwise)
+ """
+
+ def __init__(self, height: int, width: int) -> None:
+ self.height = height
+ self.width = width
+
+ def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) -> Tuple[int, int]:
+ """
+ Generate contour coordinates. Given the previous and current coordinates of border positions,
+ returns the int pixel that marks the extremity of the segmented pixels.
+
+ Args:
+ current: coordinates of the current border position.
+ previous: coordinates of the previous border position.
+ """
+
+ p_delta = (current[0] - previous[0], current[1] - previous[1])
+
+ if p_delta == (0.0, 1.0) or p_delta == (0.5, 0.5) or p_delta == (1.0, 0.0):
+ row = int(current[0] + 0.5)
+ col = int(current[1])
+ elif p_delta == (0.0, -1.0) or p_delta == (0.5, -0.5):
+ row = int(current[0])
+ col = int(current[1])
+ elif p_delta == (-1, 0.0) or p_delta == (-0.5, -0.5):
+ row = int(current[0])
+ col = int(current[1] + 0.5)
+ elif p_delta == (-0.5, 0.5):
+ row = int(current[0] + 0.5)
+ col = int(current[1] + 0.5)
+
+ return row, col
+
+ def _calculate_distance_from_topleft(self, sequence: Sequence[Tuple[int, int]]) -> int:
+ """
+ Each sequence of coordinates describes a boundary between foreground and background starting and ending at two sides
+ of the bounding box. To order the sequences correctly, we compute the distance from the topleft of the bounding box
+ around the perimeter in a clockwise direction.
+
+ Args:
+ sequence: list of border points coordinates.
+
+ Returns:
+ the distance round the perimeter of the bounding box from the top-left origin
+ """
+ distance: int
+ first_coord = sequence[0]
+ if first_coord[0] == 0:
+ distance = first_coord[1]
+ elif first_coord[1] == self.width - 1:
+ distance = self.width + first_coord[0]
+ elif first_coord[0] == self.height - 1:
+ distance = 2 * self.width + self.height - first_coord[1]
+ else:
+ distance = 2 * (self.width + self.height) - first_coord[0]
+
+ return distance
+
+ def __call__(self, contours: List[np.ndarray]) -> np.ndarray:
+ """
+ Args:
+ contours: list of (n, 2)-ndarrays, scipy-style clockwise line segments, with lines separating foreground/background.
+ Each contour is an ndarray of shape (n, 2), consisting of n (row, column) coordinates along the contour.
+ """
+ pixels: List[Tuple[int, int]] = []
+ sequences = []
+ corners = [False, False, False, False]
+
+ for group in contours:
+ sequence: List[Tuple[int, int]] = []
+ last_added = None
+ prev = None
+ corner = -1
+
+ for i, coord in enumerate(group):
+ if i == 0:
+ # originating from the top, so must be heading south east
+ if coord[0] == 0.0:
+ corner = 1
+ pixel = (0, int(coord[1] - 0.5))
+ if pixel[1] == self.width - 1:
+ corners[1] = True
+ elif pixel[1] == 0.0:
+ corners[0] = True
+ # originating from the left, so must be heading north east
+ elif coord[1] == 0.0:
+ corner = 0
+ pixel = (int(coord[0] + 0.5), 0)
+ # originating from the bottom, so must be heading north west
+ elif coord[0] == self.height - 1:
+ corner = 3
+ pixel = (int(coord[0]), int(coord[1] + 0.5))
+ if pixel[1] == self.width - 1:
+ corners[2] = True
+ # originating from the right, so must be heading south west
+ elif coord[1] == self.width - 1:
+ corner = 2
+ pixel = (int(coord[0] - 0.5), int(coord[1]))
+ sequence.append(pixel)
+ last_added = pixel
+ elif i == len(group) - 1:
+ # add this point
+ pixel = self._generate_contour_coord(coord, prev) # type: ignore
+ if pixel != last_added:
+ sequence.append(pixel)
+ last_added = pixel
+ elif np.any(coord - prev != group[i + 1] - coord):
+ pixel = self._generate_contour_coord(coord, prev) # type: ignore
+ if pixel != last_added:
+ sequence.append(pixel)
+ last_added = pixel
+
+ # flag whether each corner has been crossed
+ if i == len(group) - 1:
+ if corner == 0:
+ if coord[0] == 0:
+ corners[corner] = True
+ elif corner == 1:
+ if coord[1] == self.width - 1:
+ corners[corner] = True
+ elif corner == 2:
+ if coord[0] == self.height - 1:
+ corners[corner] = True
+ elif corner == 3:
+ if coord[1] == 0.0:
+ corners[corner] = True
+
+ prev = coord
+ dist = self._calculate_distance_from_topleft(sequence)
+
+ sequences.append({"distance": dist, "sequence": sequence})
+
+ # check whether we need to insert any missing corners
+ if corners[0] is False:
+ sequences.append({"distance": 0, "sequence": [(0, 0)]})
+ if corners[1] is False:
+ sequences.append({"distance": self.width, "sequence": [(0, self.width - 1)]})
+ if corners[2] is False:
+ sequences.append({"distance": self.width + self.height, "sequence": [(self.height - 1, self.width - 1)]})
+ if corners[3] is False:
+ sequences.append({"distance": 2 * self.width + self.height, "sequence": [(self.height - 1, 0)]})
+
+ # join the sequences into a single contour
+ # starting at top left and rotating clockwise
+ sequences.sort(key=lambda x: x.get("distance")) # type: ignore
+
+ last = (-1, -1)
+ for _sequence in sequences:
+ if _sequence["sequence"][0] == last: # type: ignore
+ pixels.pop()
+ if pixels:
+ pixels = [*pixels, *_sequence["sequence"]] # type: ignore
+ else:
+ pixels = _sequence["sequence"] # type: ignore
+ last = pixels[-1]
+
+ if pixels[0] == last:
+ pixels.pop(0)
+
+ if pixels[0] == (0, 0):
+ pixels.append(pixels.pop(0))
+
+ return np.flip(convert_to_numpy(pixels, dtype=np.int32)) # type: ignore
+
+
+class GenerateInstanceContour(Transform):
+ """
+ Generate contour for each instance in a 2D array. Use `GenerateSuccinctContour` to only include
+ the pixels to which lines need to be drawn
+
+ Args:
+ points_num: assumed that the created contour does not form a contour if it does not contain more points
+ than the specified value. Defaults to 3.
+ level: optional. Value along which to find contours in the array. By default, the level is set
+ to (max(image) + min(image)) / 2.
+
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(self, points_num: int = 3, level: Optional[float] = None) -> None:
+ self.level = level
+ self.points_num = points_num
+
+ def __call__(self, image: NdarrayOrTensor, offset: Optional[Sequence[int]] = (0, 0)) -> np.ndarray:
+ """
+ Args:
+ image: instance-level segmentation result. Shape should be [C, H, W]
+ offset: optional, offset of starting position of the instance in the array, default is (0, 0).
+ """
+ image = image.squeeze() # squeeze channel dim
+ image = convert_to_numpy(image)
+ inst_contour_cv = find_contours(image, level=self.level)
+ generate_contour = GenerateSuccinctContour(image.shape[0], image.shape[1])
+ inst_contour = generate_contour(inst_contour_cv)
+
+ # < `self.points_num` points don't make a contour, so skip, likely artifact too
+ # as the contours obtained via approximation => too small or sthg
+ if inst_contour.shape[0] < self.points_num:
+ print(f"< {self.points_num} points don't make a contour, so skip")
+ return None # type: ignore
+ # check for tricky shape
+ elif len(inst_contour.shape) != 2:
+ print(f"{len(inst_contour.shape)} != 2, check for tricky shape")
+ return None # type: ignore
+ else:
+ inst_contour[:, 0] += offset[0] # type: ignore
+ inst_contour[:, 1] += offset[1] # type: ignore
+ return inst_contour
+
+
+class GenerateInstanceCentroid(Transform):
+ """
+ Generate instance centroid using `skimage.measure.centroid`.
+
+ Args:
+ dtype: the data type of output centroid.
+
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(self, dtype: Optional[DtypeLike] = int) -> None:
+ self.dtype = dtype
+
+ def __call__(self, image: NdarrayOrTensor, offset: Union[Sequence[int], int] = 0) -> np.ndarray:
+ """
+ Args:
+ image: instance-level segmentation result. Shape should be [1, H, W, [D]]
+ offset: optional, offset of starting position of the instance in the array, default is 0 for each dim.
+
+ """
+ image = convert_to_numpy(image)
+ image = image.squeeze(0) # squeeze channel dim
+ ndim = len(image.shape)
+ offset = ensure_tuple_rep(offset, ndim)
+
+ inst_centroid = centroid(image)
+ for i in range(ndim):
+ inst_centroid[i] += offset[i]
+
+ return convert_to_dst_type(inst_centroid, image, dtype=self.dtype)[0] # type: ignore
+
+
+class GenerateInstanceType(Transform):
+ """
+ Generate instance type and probability for each instance.
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def __call__( # type: ignore
+ self, type_pred: NdarrayOrTensor, seg_pred: NdarrayOrTensor, bbox: np.ndarray, instance_id: int
+ ) -> Tuple[int, float]:
+ """
+ Args:
+ type_pred: pixel-level type prediction map after activation function.
+ seg_pred: pixel-level segmentation prediction map after activation function.
+ bbox: bounding box coordinates of the instance, shape is [channel, 2 * spatial dims].
+ instance_id: get instance type from specified instance id.
+ """
+
+ rmin, rmax, cmin, cmax = bbox.flatten()
+ seg_map_crop = seg_pred[0, rmin:rmax, cmin:cmax]
+ type_map_crop = type_pred[0, rmin:rmax, cmin:cmax]
+
+ seg_map_crop = convert_to_dst_type(seg_map_crop == instance_id, type_map_crop, dtype=bool)[0]
+
+ inst_type = type_map_crop[seg_map_crop] # type: ignore
+ type_list, type_pixels = unique(inst_type, return_counts=True)
+ type_list = list(zip(type_list, type_pixels))
+ type_list = sorted(type_list, key=lambda x: x[1], reverse=True) # type: ignore
+ inst_type = type_list[0][0]
+ if inst_type == 0: # ! pick the 2nd most dominant if exist
+ if len(type_list) > 1:
+ inst_type = type_list[1][0]
+ type_dict = {v[0]: v[1] for v in type_list}
+ type_prob = type_dict[inst_type] / (sum(seg_map_crop) + 1.0e-6)
+
+ return (int(inst_type), float(type_prob))
diff --git a/monai/apps/pathology/transforms/post/dictionary.py b/monai/apps/pathology/transforms/post/dictionary.py
new file mode 100644
index 00000000000..c358eebf397
--- /dev/null
+++ b/monai/apps/pathology/transforms/post/dictionary.py
@@ -0,0 +1,491 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict, Hashable, Mapping, Optional
+
+import numpy as np
+
+from monai.apps.pathology.transforms.post.array import (
+ GenerateDistanceMap,
+ GenerateInstanceBorder,
+ GenerateInstanceCentroid,
+ GenerateInstanceContour,
+ GenerateInstanceType,
+ GenerateSuccinctContour,
+ GenerateWatershedMarkers,
+ GenerateWatershedMask,
+ Watershed,
+)
+from monai.config.type_definitions import DtypeLike, KeysCollection, NdarrayOrTensor
+from monai.transforms.transform import MapTransform
+from monai.utils import optional_import
+
+find_contours, _ = optional_import("skimage.measure", name="find_contours")
+moments, _ = optional_import("skimage.measure", name="moments")
+
+__all__ = [
+ "WatershedD",
+ "WatershedDict",
+ "Watershedd",
+ "GenerateWatershedMaskD",
+ "GenerateWatershedMaskDict",
+ "GenerateWatershedMaskd",
+ "GenerateInstanceBorderD",
+ "GenerateInstanceBorderDict",
+ "GenerateInstanceBorderd",
+ "GenerateDistanceMapD",
+ "GenerateDistanceMapDict",
+ "GenerateDistanceMapd",
+ "GenerateWatershedMarkersD",
+ "GenerateWatershedMarkersDict",
+ "GenerateWatershedMarkersd",
+ "GenerateSuccinctContourDict",
+ "GenerateSuccinctContourD",
+ "GenerateSuccinctContourd",
+ "GenerateInstanceContourDict",
+ "GenerateInstanceContourD",
+ "GenerateInstanceContourd",
+ "GenerateInstanceCentroidDict",
+ "GenerateInstanceCentroidD",
+ "GenerateInstanceCentroidd",
+ "GenerateInstanceTypeDict",
+ "GenerateInstanceTypeD",
+ "GenerateInstanceTyped",
+]
+
+
+class Watershedd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.Watershed`.
+ Use `skimage.segmentation.watershed` to get instance segmentation results from images.
+ See: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.watershed.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ See also: monai.transforms.MapTransform
+ mask_key: keys of mask used in watershed. Only points at which mask == True will be labeled.
+ markers_key: keys of markers used in watershed. If None (no markers given), the local minima of the image are
+ used as markers.
+ connectivity: An array with the same number of dimensions as image whose non-zero elements indicate neighbors
+ for connection. Following the scipy convention, default is a one-connected array of the dimension of the
+ image.
+ dtype: target data content type to convert. Defaults to np.uint8.
+ allow_missing_keys: don't raise exception if key is missing.
+
+ Raises:
+ ValueError: when the `image` shape is not [1, H, W].
+ ValueError: when the `mask` shape is not [1, H, W].
+
+ """
+
+ backend = Watershed.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ mask_key: Optional[str] = "mask",
+ markers_key: Optional[str] = None,
+ connectivity: Optional[int] = 1,
+ dtype: DtypeLike = np.uint8,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.mask_key = mask_key
+ self.markers_key = markers_key
+ self.transform = Watershed(connectivity=connectivity, dtype=dtype)
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ markers = d[self.markers_key] if self.markers_key else None
+ mask = d[self.mask_key] if self.mask_key else None
+
+ for key in self.key_iterator(d):
+ d[key] = self.transform(d[key], mask, markers)
+
+ return d
+
+
+class GenerateWatershedMaskd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateWatershedMask`.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ mask_key: the mask will be written to the value of `{mask_key}`.
+ softmax: if True, apply a softmax function to the prediction.
+ sigmoid: if True, apply a sigmoid function to the prediction.
+ threshold: if not None, threshold the float values to int number 0 or 1 with specified theashold.
+ remove_small_objects: whether need to remove some objects in the marker. Defaults to True.
+ min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10.
+ dtype: target data content type to convert. Defaults to np.uint8.
+ allow_missing_keys: don't raise exception if key is missing.
+
+ """
+
+ backend = GenerateWatershedMask.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ mask_key: str = "mask",
+ softmax: bool = True,
+ sigmoid: bool = False,
+ threshold: Optional[float] = None,
+ remove_small_objects: bool = True,
+ min_size: int = 10,
+ dtype: DtypeLike = np.uint8,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.mask_key = mask_key
+ self.transform = GenerateWatershedMask(
+ softmax=softmax,
+ sigmoid=sigmoid,
+ threshold=threshold,
+ remove_small_objects=remove_small_objects,
+ min_size=min_size,
+ dtype=dtype,
+ )
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ mask = self.transform(d[key])
+ key_to_add = f"{self.mask_key}"
+ if key_to_add in d:
+ raise KeyError(f"Mask with key {key_to_add} already exists.")
+ d[key_to_add] = mask
+ return d
+
+
+class GenerateInstanceBorderd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateInstanceBorder`.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ hover_map_key: keys of hover map used to generate probability map.
+ border_key: the instance border map will be written to the value of `{border_key}`.
+ kernel_size: the size of the Sobel kernel. Defaults to 21.
+ min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10.
+ remove_small_objects: whether need to remove some objects in segmentation results. Defaults to True.
+ dtype: target data content type to convert, default is np.float32.
+ allow_missing_keys: don't raise exception if key is missing.
+
+ Raises:
+ ValueError: when the `hover_map` has only one value.
+ ValueError: when the `sobel gradient map` has only one value.
+
+ """
+
+ backend = GenerateInstanceBorder.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ hover_map_key: str = "hover_map",
+ border_key: str = "border",
+ kernel_size: int = 21,
+ min_size: int = 10,
+ remove_small_objects: bool = True,
+ dtype: DtypeLike = np.float32,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.hover_map_key = hover_map_key
+ self.border_key = border_key
+ self.transform = GenerateInstanceBorder(
+ kernel_size=kernel_size, remove_small_objects=remove_small_objects, min_size=min_size, dtype=dtype
+ )
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ instance_border = self.transform(d[key], d[self.hover_map_key])
+ key_to_add = f"{self.border_key}"
+ if key_to_add in d:
+ raise KeyError(f"Instance border map with key {key_to_add} already exists.")
+ d[key_to_add] = instance_border
+ return d
+
+
+class GenerateDistanceMapd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateDistanceMap`.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ border_key: keys of the instance border map used to generate distance map.
+ dist_key: the distance map will be written to the value of `{dist_key}`.
+ smooth_fn: execute smooth function on distance map. Defaults to None. You can specify
+ callable functions for smoothing.
+ For example, if you want apply gaussian smooth, you can specify `smooth_fn = GaussianSmooth()`
+ dtype: target data content type to convert, default is np.float32.
+ allow_missing_keys: don't raise exception if key is missing.
+ """
+
+ backend = GenerateDistanceMap.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ border_key: str = "border",
+ dist_key: str = "dist",
+ smooth_fn: Optional[Callable] = None,
+ dtype: DtypeLike = np.float32,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.border_key = border_key
+ self.dist_key = dist_key
+ self.transform = GenerateDistanceMap(smooth_fn=smooth_fn, dtype=dtype)
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ distance_map = self.transform(d[key], d[self.border_key])
+ key_to_add = f"{self.dist_key}"
+ if key_to_add in d:
+ raise KeyError(f"Distance map with key {key_to_add} already exists.")
+ d[key_to_add] = distance_map
+ return d
+
+
+class GenerateWatershedMarkersd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.array.GenerateWatershedMarkers`.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ border_key: keys of the instance border map used to generate markers.
+ markers_key: the markers will be written to the value of `{markers_key}`.
+ threshold: threshold the float values of instance border map to int 0 or 1 with specified theashold.
+ It turns uncertain area to 1 and other area to 0. Defaults to 0.4.
+ radius: the radius of the disk-shaped footprint used in `opening`. Defaults to 2.
+ min_size: objects smaller than this size are removed if `remove_small_objects` is True. Defaults to 10.
+ remove_small_objects: whether need to remove some objects in the marker. Defaults to True.
+ postprocess_fn: execute additional post transformation on marker. Defaults to None.
+ dtype: target data content type to convert, default is np.uint8.
+ allow_missing_keys: don't raise exception if key is missing.
+ """
+
+ backend = GenerateWatershedMarkers.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ border_key: str = "border",
+ markers_key: str = "markers",
+ threshold: float = 0.4,
+ radius: int = 2,
+ min_size: int = 10,
+ remove_small_objects: bool = True,
+ postprocess_fn: Optional[Callable] = None,
+ dtype: DtypeLike = np.uint8,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.border_key = border_key
+ self.markers_key = markers_key
+ self.transform = GenerateWatershedMarkers(
+ threshold=threshold,
+ radius=radius,
+ min_size=min_size,
+ remove_small_objects=remove_small_objects,
+ postprocess_fn=postprocess_fn,
+ dtype=dtype,
+ )
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ markers = self.transform(d[key], d[self.border_key])
+ key_to_add = f"{self.markers_key}"
+ if key_to_add in d:
+ raise KeyError(f"Markers with key {key_to_add} already exists.")
+ d[key_to_add] = markers
+ return d
+
+
+class GenerateSuccinctContourd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateSuccinctContour`.
+ Converts Scipy-style contours(generated by skimage.measure.find_contours) to a more succinct version which
+ only includes the pixels to which lines need to be drawn (i.e. not the intervening pixels along each line).
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ height: height of bounding box, used to detect direction of line segment.
+ width: width of bounding box, used to detect direction of line segment.
+ allow_missing_keys: don't raise exception if key is missing.
+
+ """
+
+ backend = GenerateSuccinctContour.backend
+
+ def __init__(self, keys: KeysCollection, height: int, width: int, allow_missing_keys: bool = False) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.converter = GenerateSuccinctContour(height=height, width=width)
+
+ def __call__(self, data):
+ d = dict(data)
+ for key in self.key_iterator(d):
+ d[key] = self.converter(d[key])
+
+ return d
+
+
+class GenerateInstanceContourd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateInstanceContour`.
+ Generate contour for each instance in a 2D array. Use `GenerateSuccinctContour` to only include the pixels
+ to which lines need to be drawn
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ contour_key_postfix: the output contour coordinates will be written to the value of
+ `{key}_{contour_key_postfix}`.
+ offset_key: keys of offset used in `GenerateInstanceContour`.
+ points_num: assumed that the created contour does not form a contour if it does not contain more points
+ than the specified value. Defaults to 3.
+ level: optional. Value along which to find contours in the array. By default, the level is set
+ to (max(image) + min(image)) / 2.
+ allow_missing_keys: don't raise exception if key is missing.
+
+ """
+
+ backend = GenerateInstanceContour.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ contour_key_postfix: str = "contour",
+ offset_key: Optional[str] = None,
+ points_num: int = 3,
+ level: Optional[float] = None,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.converter = GenerateInstanceContour(points_num=points_num, level=level)
+ self.contour_key_postfix = contour_key_postfix
+ self.offset_key = offset_key
+
+ def __call__(self, data):
+ d = dict(data)
+ for key in self.key_iterator(d):
+ offset = d[self.offset_key] if self.offset_key else None
+ contour = self.converter(d[key], offset)
+ key_to_add = f"{key}_{self.contour_key_postfix}"
+ if key_to_add in d:
+ raise KeyError(f"Contour with key {key_to_add} already exists.")
+ d[key_to_add] = contour
+ return d
+
+
+class GenerateInstanceCentroidd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateInstanceCentroid`.
+ Generate instance centroid using `skimage.measure.centroid`.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ centroid_key_postfix: the output centroid coordinates will be written to the value of
+ `{key}_{centroid_key_postfix}`.
+ offset_key: keys of offset used in `GenerateInstanceCentroid`.
+ dtype: the data type of output centroid.
+ allow_missing_keys: don't raise exception if key is missing.
+
+ """
+
+ backend = GenerateInstanceCentroid.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ centroid_key_postfix: str = "centroid",
+ offset_key: Optional[str] = None,
+ dtype: Optional[DtypeLike] = int,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.converter = GenerateInstanceCentroid(dtype=dtype)
+ self.centroid_key_postfix = centroid_key_postfix
+ self.offset_key = offset_key
+
+ def __call__(self, data):
+ d = dict(data)
+ for key in self.key_iterator(d):
+ offset = d[self.offset_key] if self.offset_key else None
+ centroid = self.converter(d[key], offset)
+ key_to_add = f"{key}_{self.centroid_key_postfix}"
+ if key_to_add in d:
+ raise KeyError(f"Centroid with key {key_to_add} already exists.")
+ d[key_to_add] = centroid
+ return d
+
+
+class GenerateInstanceTyped(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.apps.pathology.transforms.post.array.GenerateInstanceType`.
+ Generate instance type and probability for each instance.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ type_info_key: the output instance type and probability will be written to the value of
+ `{type_info_key}`.
+ bbox_key: keys of bounding box.
+ seg_pred_key: keys of segmentation prediction map.
+ instance_id_key: keys of instance id.
+ allow_missing_keys: don't raise exception if key is missing.
+
+ """
+
+ backend = GenerateInstanceType.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ type_info_key: str = "type_info",
+ bbox_key: str = "bbox",
+ seg_pred_key: str = "seg",
+ instance_id_key: str = "id",
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.converter = GenerateInstanceType()
+ self.type_info_key = type_info_key
+ self.bbox_key = bbox_key
+ self.seg_pred_key = seg_pred_key
+ self.instance_id_key = instance_id_key
+
+ def __call__(self, data):
+ d = dict(data)
+ for key in self.key_iterator(d):
+ seg = d[self.seg_pred_key]
+ bbox = d[self.bbox_key]
+ id = d[self.instance_id_key]
+ instance_type, type_prob = self.converter(d[key], seg, bbox, id)
+ key_to_add = f"{self.type_info_key}"
+ if key_to_add in d:
+ raise KeyError(f"Type information with key {key_to_add} already exists.")
+ d[key_to_add] = {"inst_type": instance_type, "type_prob": type_prob}
+ return d
+
+
+WatershedD = WatershedDict = Watershedd
+GenerateWatershedMaskD = GenerateWatershedMaskDict = GenerateWatershedMaskd
+GenerateInstanceBorderD = GenerateInstanceBorderDict = GenerateInstanceBorderd
+GenerateDistanceMapD = GenerateDistanceMapDict = GenerateDistanceMapd
+GenerateWatershedMarkersD = GenerateWatershedMarkersDict = GenerateWatershedMarkersd
+GenerateSuccinctContourDict = GenerateSuccinctContourD = GenerateSuccinctContourd
+GenerateInstanceContourDict = GenerateInstanceContourD = GenerateInstanceContourd
+GenerateInstanceCentroidDict = GenerateInstanceCentroidD = GenerateInstanceCentroidd
+GenerateInstanceTypeDict = GenerateInstanceTypeD = GenerateInstanceTyped
diff --git a/monai/apps/reconstruction/complex_utils.py b/monai/apps/reconstruction/complex_utils.py
index 7eeeffb1b0b..0a5cdccd0d0 100644
--- a/monai/apps/reconstruction/complex_utils.py
+++ b/monai/apps/reconstruction/complex_utils.py
@@ -98,6 +98,21 @@ def convert_to_tensor_complex(
return converted_data
+def complex_abs_t(x: Tensor) -> Tensor:
+ """
+ Compute the absolute value of a complex tensor.
+
+ Args:
+ x: Input tensor with 2 channels in the last dimension representing real and imaginary parts.
+
+ Returns:
+ Absolute value along the last dimension
+ """
+ if x.shape[-1] != 2:
+ raise ValueError(f"x.shape[-1] is not 2 ({x.shape[-1]}).")
+ return (x[..., 0] ** 2 + x[..., 1] ** 2) ** 0.5 # type: ignore
+
+
def complex_abs(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Compute the absolute value of a complex array.
@@ -106,7 +121,7 @@ def complex_abs(x: NdarrayOrTensor) -> NdarrayOrTensor:
x: Input array/tensor with 2 channels in the last dimension representing real and imaginary parts.
Returns:
- Absolute value along the last dimention
+ Absolute value along the last dimension
Example:
.. code-block:: python
@@ -116,9 +131,27 @@ def complex_abs(x: NdarrayOrTensor) -> NdarrayOrTensor:
# the following line prints 5
print(complex_abs(x))
"""
- if x.shape[-1] != 2:
- raise ValueError(f"x.shape[-1] is not 2 ({x.shape[-1]}).")
- return (x[..., 0] ** 2 + x[..., 1] ** 2) ** 0.5
+ return complex_abs_t(x) # type: ignore
+
+
+def complex_mul_t(x: Tensor, y: Tensor) -> Tensor:
+ """
+ Compute complex-valued multiplication. Supports Ndim inputs with last dim equal to 2 (real/imaginary channels)
+
+ Args:
+ x: Input tensor with 2 channels in the last dimension representing real and imaginary parts.
+ y: Input tensor with 2 channels in the last dimension representing real and imaginary parts.
+
+ Returns:
+ Complex multiplication of x and y
+ """
+ if x.shape[-1] != 2 or y.shape[-1] != 2:
+ raise ValueError(f"last dim must be 2, but x.shape[-1] is {x.shape[-1]} and y.shape[-1] is {y.shape[-1]}.")
+
+ real_part = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
+ imag_part = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
+
+ return torch.stack((real_part, imag_part), dim=-1)
def complex_mul(x: NdarrayOrTensor, y: NdarrayOrTensor) -> NdarrayOrTensor:
@@ -144,20 +177,37 @@ def complex_mul(x: NdarrayOrTensor, y: NdarrayOrTensor) -> NdarrayOrTensor:
if x.shape[-1] != 2 or y.shape[-1] != 2:
raise ValueError(f"last dim must be 2, but x.shape[-1] is {x.shape[-1]} and y.shape[-1] is {y.shape[-1]}.")
- re = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
- im = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
-
if isinstance(x, Tensor):
- return torch.stack((re, im), dim=-1) # type: ignore
+ return complex_mul_t(x, y) # type: ignore
+
else:
- mult: np.ndarray = np.stack((re, im), axis=-1)
+ real_part = x[..., 0] * y[..., 0] - x[..., 1] * y[..., 1]
+ imag_part = x[..., 0] * y[..., 1] + x[..., 1] * y[..., 0]
+
+ mult: np.ndarray = np.stack((real_part, imag_part), axis=-1)
return mult
-def complex_conj(x: NdarrayOrTensor) -> NdarrayOrTensor:
+def complex_conj_t(x: Tensor) -> Tensor:
"""
Compute complex conjugate of a tensor. Supports Ndim inputs with last dim equal to 2 (real/imaginary channels)
+ Args:
+ x: Input tensor with 2 channels in the last dimension representing real and imaginary parts.
+
+ Returns:
+ Complex conjugate of x
+ """
+ if x.shape[-1] != 2:
+ raise ValueError(f"last dim must be 2, but x.shape[-1] is {x.shape[-1]}.")
+
+ return torch.stack((x[..., 0], -x[..., 1]), dim=-1)
+
+
+def complex_conj(x: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Compute complex conjugate of an/a array/tensor. Supports Ndim inputs with last dim equal to 2 (real/imaginary channels)
+
Args:
x: Input array/tensor with 2 channels in the last dimension representing real and imaginary parts.
@@ -176,7 +226,7 @@ def complex_conj(x: NdarrayOrTensor) -> NdarrayOrTensor:
raise ValueError(f"last dim must be 2, but x.shape[-1] is {x.shape[-1]}.")
if isinstance(x, Tensor):
- return torch.stack((x[..., 0], -x[..., 1]), dim=-1)
+ return complex_conj_t(x)
else:
np_conj: np.ndarray = np.stack((x[..., 0], -x[..., 1]), axis=-1)
return np_conj
diff --git a/monai/apps/reconstruction/mri_utils.py b/monai/apps/reconstruction/mri_utils.py
index fad952712d0..9c06b492d56 100644
--- a/monai/apps/reconstruction/mri_utils.py
+++ b/monai/apps/reconstruction/mri_utils.py
@@ -9,9 +9,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from torch import Tensor
+
from monai.config.type_definitions import NdarrayOrTensor
+def root_sum_of_squares_t(x: Tensor, spatial_dim: int) -> Tensor:
+ """
+ Compute the root sum of squares (rss) of the data (typically done for multi-coil MRI samples)
+
+ Args:
+ x: Input tensor
+ spatial_dim: dimension along which rss is applied
+
+ Returns:
+ rss of x along spatial_dim
+
+ Example:
+ .. code-block:: python
+
+ import numpy as np
+ x = torch.ones([2,3])
+ # the following line prints Tensor([1.41421356, 1.41421356, 1.41421356])
+ print(rss(x,spatial_dim=0))
+ """
+ rss_x: Tensor = (x**2).sum(spatial_dim) ** 0.5
+ return rss_x
+
+
def root_sum_of_squares(x: NdarrayOrTensor, spatial_dim: int) -> NdarrayOrTensor:
"""
Compute the root sum of squares (rss) of the data (typically done for multi-coil MRI samples)
@@ -31,5 +56,5 @@ def root_sum_of_squares(x: NdarrayOrTensor, spatial_dim: int) -> NdarrayOrTensor
# the following line prints array([1.41421356, 1.41421356, 1.41421356])
print(rss(x,spatial_dim=0))
"""
- rss_x: NdarrayOrTensor = (x**2).sum(spatial_dim) ** 0.5
+ rss_x: NdarrayOrTensor = root_sum_of_squares_t(x, spatial_dim) # type: ignore
return rss_x
diff --git a/monai/apps/reconstruction/networks/blocks/__init__.py b/monai/apps/reconstruction/networks/blocks/__init__.py
new file mode 100644
index 00000000000..1e97f894078
--- /dev/null
+++ b/monai/apps/reconstruction/networks/blocks/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/apps/reconstruction/networks/blocks/varnetblock.py b/monai/apps/reconstruction/networks/blocks/varnetblock.py
new file mode 100644
index 00000000000..daaa3efbf32
--- /dev/null
+++ b/monai/apps/reconstruction/networks/blocks/varnetblock.py
@@ -0,0 +1,79 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from monai.apps.reconstruction.networks.nets.utils import sensitivity_map_expand, sensitivity_map_reduce
+
+
+class VarNetBlock(nn.Module):
+ """
+ A variational block based on Sriram et. al., "End-to-end variational networks for accelerated MRI reconstruction".
+ It applies data consistency and refinement to the intermediate kspace and combines those results.
+
+ Modified and adopted from: https://github.com/facebookresearch/fastMRI
+
+ Args:
+ refinement_model: the model used for refinement (typically a U-Net but can be any deep learning model
+ that performs well when the input and output are in image domain (e.g., a convolutional network).
+ spatial_dims: is 2 for 2D data and is 3 for 3D data
+ """
+
+ def __init__(self, refinement_model: nn.Module, spatial_dims: int = 2):
+ super().__init__()
+ self.model = refinement_model
+ self.spatial_dims = spatial_dims
+ self.dc_weight = nn.Parameter(torch.ones(1)) # learned scalar as the multiplier of the DC block
+
+ buffer_shape = [1 for _ in range(spatial_dims + 3)] # 3 denotes the batch, channel, and real/complex dimensions
+ self.register_buffer("zeros", torch.zeros(buffer_shape))
+
+ def soft_dc(self, x: Tensor, ref_kspace: Tensor, mask: Tensor) -> Tensor:
+ """
+ Applies data consistency to input x. Suppose x is an intermediate estimate of the kspace and ref_kspace
+ is the reference under-sampled measurement. This function returns mask * (x - ref_kspace). View this as the
+ residual between the original under-sampled kspace and the estimate given by the network.
+
+ Args:
+ x: 2D kspace (B,C,H,W,2) with the last dimension being 2 (for real/imaginary parts) and C denoting the
+ coil dimension. 3D data will have the shape (B,C,H,W,D,2).
+ ref_kspace: original under-sampled kspace with the same shape as x.
+ mask: the under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.
+
+ Returns:
+ Output of DC block with the same shape as x
+ """
+ return torch.where(mask, x - ref_kspace, self.zeros) * self.dc_weight # type: ignore
+
+ def forward(self, current_kspace: Tensor, ref_kspace: Tensor, mask: Tensor, sens_maps: Tensor) -> Tensor:
+ """
+ Args:
+ current_kspace: Predicted kspace from the previous block. It's a 2D kspace (B,C,H,W,2)
+ with the last dimension being 2 (for real/imaginary parts) and C denoting the
+ coil dimension. 3D data will have the shape (B,C,H,W,D,2).
+ ref_kspace: reference kspace for applying data consistency (is the under-sampled kspace in MRI reconstruction).
+ Its shape is the same as current_kspace.
+ mask: the under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.
+ sens_maps: coil sensitivity maps with the same shape as current_kspace
+
+ Returns:
+ Output of VarNetBlock with the same shape as current_kspace
+ """
+ dc_out = self.soft_dc(current_kspace, ref_kspace, mask) # output of DC block
+ refinement_out = sensitivity_map_expand(
+ self.model(sensitivity_map_reduce(current_kspace, sens_maps, spatial_dims=self.spatial_dims)),
+ sens_maps,
+ spatial_dims=self.spatial_dims,
+ ) # output of refinement model
+ output = current_kspace - dc_out - refinement_out
+ return output
diff --git a/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py b/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py
index 208196da02c..94568db90f5 100644
--- a/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py
+++ b/monai/apps/reconstruction/networks/nets/coil_sensitivity_model.py
@@ -9,13 +9,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
-from monai.apps.reconstruction.mri_utils import root_sum_of_squares
+from monai.apps.reconstruction.mri_utils import root_sum_of_squares_t
from monai.apps.reconstruction.networks.nets.complex_unet import ComplexUnet
from monai.apps.reconstruction.networks.nets.utils import (
reshape_batch_channel_to_channel_dim,
@@ -83,7 +83,7 @@ def __init__(
self.spatial_dims = spatial_dims
self.coil_dim = coil_dim
- def get_fully_sampled_region(self, mask: Tensor) -> Sequence[int]:
+ def get_fully_sampled_region(self, mask: Tensor) -> Tuple[int, int]:
"""
Extracts the size of the fully-sampled part of the kspace. Note that when a kspace
is under-sampled, a part of its center is fully sampled. This part is called the Auto
@@ -126,14 +126,15 @@ def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor:
# take out the fully-sampled region and set the rest of the data to zero
x = torch.zeros_like(masked_kspace)
start = (mask.shape[-2] - num_low_freqs + 1) // 2 # this marks the start of center extraction
- x[..., start : start + num_low_freqs] = masked_kspace[..., start : start + num_low_freqs]
+ x[..., start : start + num_low_freqs, :] = masked_kspace[..., start : start + num_low_freqs, :]
# apply inverse fourier to the extracted fully-sampled data
- x = ifftn_centered_t(x, spatial_dims=self.spatial_dims)
+ x = ifftn_centered_t(x, spatial_dims=self.spatial_dims, is_complex=True)
x, b = reshape_channel_to_batch_dim(x) # shape of x will be (B*C,1,...)
x = self.conv_net(x)
x = reshape_batch_channel_to_channel_dim(x, b) # shape will be (B,C,...)
# normalize the maps
- x /= root_sum_of_squares(x, spatial_dim=self.coil_dim).unsqueeze(self.coil_dim) # type: ignore
+ x = x / root_sum_of_squares_t(x, spatial_dim=self.coil_dim).unsqueeze(self.coil_dim)
+
return x
diff --git a/monai/apps/reconstruction/networks/nets/complex_unet.py b/monai/apps/reconstruction/networks/nets/complex_unet.py
index c927fffdffc..ccbb5731a1f 100644
--- a/monai/apps/reconstruction/networks/nets/complex_unet.py
+++ b/monai/apps/reconstruction/networks/nets/complex_unet.py
@@ -11,17 +11,17 @@
from typing import Optional, Sequence, Union
-import torch
import torch.nn as nn
from torch import Tensor
from monai.apps.reconstruction.networks.nets.utils import (
complex_normalize,
+ divisible_pad_t,
+ inverse_divisible_pad_t,
reshape_channel_complex_to_last_dim,
reshape_complex_to_channel_dim,
)
from monai.networks.nets.basic_unet import BasicUNet
-from monai.transforms import DivisiblePad
class ComplexUnet(nn.Module):
@@ -84,6 +84,7 @@ def __init__(
if params[0][1] != 2:
raise ValueError(f"in_channels should be 2 but it's {params[0][1]}.")
self.unet = conv_net
+
self.pad_factor = pad_factor
def forward(self, x: Tensor) -> Tensor:
@@ -98,14 +99,13 @@ def forward(self, x: Tensor) -> Tensor:
x = reshape_complex_to_channel_dim(x) # x will be of shape (B,C*2,H,W)
x, mean, std = complex_normalize(x) # x will be of shape (B,C*2,H,W)
# pad input
- padder = DivisiblePad(k=self.pad_factor)
- x = torch.stack(
- [padder(xi) for xi in x]
+ x, padding_sizes = divisible_pad_t(
+ x, k=self.pad_factor
) # x will be of shape (B,C*2,H',W') where H' and W' are for after padding
x = self.unet(x)
# inverse padding
- x = torch.stack([padder.inverse(xi) for xi in x]) # x will be of shape (B,C*2,H,W)
+ x = inverse_divisible_pad_t(x, padding_sizes) # x will be of shape (B,C*2,H,W)
x = x * std + mean
x = reshape_channel_complex_to_last_dim(x) # x will be of shape (B,C,H,W,2)
diff --git a/monai/apps/reconstruction/networks/nets/utils.py b/monai/apps/reconstruction/networks/nets/utils.py
index f8909f2d927..b97cdab7860 100644
--- a/monai/apps/reconstruction/networks/nets/utils.py
+++ b/monai/apps/reconstruction/networks/nets/utils.py
@@ -12,12 +12,17 @@
This script contains utility functions for developing new networks/blocks in PyTorch.
"""
-from typing import Sequence
+import math
+from typing import Tuple
from torch import Tensor
+from torch.nn import functional as F
+from monai.apps.reconstruction.complex_utils import complex_conj_t, complex_mul_t
+from monai.networks.blocks.fft_utils_t import fftn_centered_t, ifftn_centered_t
-def reshape_complex_to_channel_dim(x: Tensor) -> Tensor: # type: ignore
+
+def reshape_complex_to_channel_dim(x: Tensor) -> Tensor:
"""
Swaps the complex dimension with the channel dimension so that the network treats real/imaginary
parts as two separate channels.
@@ -39,8 +44,11 @@ def reshape_complex_to_channel_dim(x: Tensor) -> Tensor: # type: ignore
b, c, h, w, d, two = x.shape
return x.permute(0, 5, 1, 2, 3, 4).contiguous().view(b, 2 * c, h, w, d)
+ else:
+ raise ValueError(f"only 2D (B,C,H,W,2) and 3D (B,C,H,W,D,2) data are supported but x has shape {x.shape}")
+
-def reshape_channel_complex_to_last_dim(x: Tensor) -> Tensor: # type: ignore
+def reshape_channel_complex_to_last_dim(x: Tensor) -> Tensor:
"""
Swaps the complex dimension with the channel dimension so that the network output has 2 as its last dimension
@@ -63,21 +71,33 @@ def reshape_channel_complex_to_last_dim(x: Tensor) -> Tensor: # type: ignore
c = c2 // 2
return x.view(b, 2, c, h, w, d).permute(0, 2, 3, 4, 5, 1)
+ else:
+ raise ValueError(f"only 2D (B,C*2,H,W) and 3D (B,C*2,H,W,D) data are supported but x has shape {x.shape}")
+
-def reshape_channel_to_batch_dim(x: Tensor) -> Sequence:
+def reshape_channel_to_batch_dim(x: Tensor) -> Tuple[Tensor, int]:
"""
Combines batch and channel dimensions.
Args:
- x: Ndim input of shape shape (B,C,...)
+ x: input of shape (B,C,H,W,2) for 2D data or (B,C,H,W,D,2) for 3D data
Returns:
A tuple containing:
(1) output of shape (B*C,1,...)
(2) batch size
"""
- b, c, *other = x.shape
- return x.contiguous().view(b * c, 1, *other), b
+
+ if len(x.shape) == 5: # this is 2D
+ b, c, h, w, two = x.shape
+ return x.contiguous().view(b * c, 1, h, w, two), b
+
+ elif len(x.shape) == 6: # this is 3D
+ b, c, h, w, d, two = x.shape
+ return x.contiguous().view(b * c, 1, h, w, d, two), b
+
+ else:
+ raise ValueError(f"only 2D (B,C,H,W,2) and 3D (B,C,H,W,D,2) data are supported but x has shape {x.shape}")
def reshape_batch_channel_to_channel_dim(x: Tensor, batch_size: int) -> Tensor:
@@ -85,18 +105,27 @@ def reshape_batch_channel_to_channel_dim(x: Tensor, batch_size: int) -> Tensor:
Detaches batch and channel dimensions.
Args:
- x: Ndim input of shape (B*C,1,...)
+ x: input of shape (B*C,1,H,W,2) for 2D data or (B*C,1,H,W,D,2) for 3D data
batch_size: batch size
Returns:
output of shape (B,C,...)
"""
- bc, one, *other = x.shape # bc represents B*C
- c = bc // batch_size
- return x.view(batch_size, c, *other)
+ if len(x.shape) == 5: # this is 2D
+ bc, one, h, w, two = x.shape # bc represents B*C
+ c = bc // batch_size
+ return x.view(batch_size, c, h, w, two)
+
+ elif len(x.shape) == 6: # this is 3D
+ bc, one, h, w, d, two = x.shape # bc represents B*C
+ c = bc // batch_size
+ return x.view(batch_size, c, h, w, d, two)
+
+ else:
+ raise ValueError(f"only 2D (B*C,1,H,W,2) and 3D (B*C,1,H,W,D,2) data are supported but x has shape {x.shape}")
-def complex_normalize(x: Tensor) -> Sequence: # type: ignore
+def complex_normalize(x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""
Performs layer mean-std normalization for complex data. Normalization is done for each batch member
along each part (part refers to real and imaginary parts), separately.
@@ -130,4 +159,149 @@ def complex_normalize(x: Tensor) -> Sequence: # type: ignore
.view(b, c, 1, 1, 1)
)
x = x.view(b, c, h, w, d)
- return (x - mean) / std, mean, std # type: ignore
+ return (x - mean) / std, mean, std
+
+ else:
+ raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}")
+
+
+def divisible_pad_t(
+ x: Tensor, k: int = 16
+) -> Tuple[Tensor, Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], int, int, int]]:
+ """
+ Pad input to feed into the network (torch script compatible)
+
+ Args:
+ x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
+ k: padding factor. each padded dimension will be divisible by k.
+
+ Returns:
+ A tuple containing
+ (1) padded input
+ (2) pad sizes (in order to reverse padding if needed)
+
+ Example:
+ .. code-block:: python
+
+ import torch
+
+ # 2D data
+ x = torch.ones([3,2,50,70])
+ x_pad,padding_sizes = divisible_pad_t(x, k=16)
+ # the following line should print (3, 2, 64, 80)
+ print(x_pad.shape)
+
+ # 3D data
+ x = torch.ones([3,2,50,70,80])
+ x_pad,padding_sizes = divisible_pad_t(x, k=16)
+ # the following line should print (3, 2, 64, 80, 80)
+ print(x_pad.shape)
+
+ """
+ if len(x.shape) == 4: # this is 2D
+ b, c, h, w = x.shape
+ w_mult = ((w - 1) | (k - 1)) + 1 # OR with (k-1) and then +1 makes sure padding is divisible by k
+ h_mult = ((h - 1) | (k - 1)) + 1
+ w_pad = floor_ceil((w_mult - w) / 2)
+ h_pad = floor_ceil((h_mult - h) / 2)
+ x = F.pad(x, w_pad + h_pad)
+ # dummy values for the 3rd spatial dimension
+ d_mult = -1
+ d_pad = (-1, -1)
+ pad_sizes = (h_pad, w_pad, d_pad, h_mult, w_mult, d_mult)
+
+ elif len(x.shape) == 5: # this is 3D
+ b, c, h, w, d = x.shape
+ w_mult = ((w - 1) | (k - 1)) + 1
+ h_mult = ((h - 1) | (k - 1)) + 1
+ d_mult = ((d - 1) | (k - 1)) + 1
+ w_pad = floor_ceil((w_mult - w) / 2)
+ h_pad = floor_ceil((h_mult - h) / 2)
+ d_pad = floor_ceil((d_mult - d) / 2)
+ x = F.pad(x, d_pad + w_pad + h_pad)
+ pad_sizes = (h_pad, w_pad, d_pad, h_mult, w_mult, d_mult)
+
+ else:
+ raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}")
+
+ return x, pad_sizes
+
+
+def inverse_divisible_pad_t(
+ x: Tensor, pad_sizes: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], int, int, int]
+) -> Tensor:
+ """
+ De-pad network output to match its original shape
+
+ Args:
+ x: input of shape (B,C,H,W) for 2D data or (B,C,H,W,D) for 3D data
+ pad_sizes: padding values
+
+ Returns:
+ de-padded input
+ """
+ h_pad, w_pad, d_pad, h_mult, w_mult, d_mult = pad_sizes
+
+ if len(x.shape) == 4: # this is 2D
+ return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]]
+
+ elif len(x.shape) == 5: # this is 3D
+ return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1], d_pad[0] : d_mult - d_pad[1]]
+
+ else:
+ raise ValueError(f"only 2D (B,C,H,W) and 3D (B,C,H,W,D) data are supported but x has shape {x.shape}")
+
+
+def floor_ceil(n: float) -> Tuple[int, int]:
+ """
+ Returns floor and ceil of the input
+
+ Args:
+ n: input number
+
+ Returns:
+ A tuple containing:
+ (1) floor(n)
+ (2) ceil(n)
+ """
+ return math.floor(n), math.ceil(n)
+
+
+def sensitivity_map_reduce(kspace: Tensor, sens_maps: Tensor, spatial_dims: int = 2) -> Tensor:
+ """
+ Reduces coil measurements to a corresponding image based on the given sens_maps. Let's say there
+ are C coil measurements inside kspace, then this function multiplies the conjugate of each coil sensitivity map with the
+ corresponding coil image. The result of this process will be C images. Summing those images together gives the
+ resulting "reduced image."
+
+ Args:
+ kspace: 2D kspace (B,C,H,W,2) with the last dimension being 2 (for real/imaginary parts) and C denoting the
+ coil dimension. 3D data will have the shape (B,C,H,W,D,2).
+ sens_maps: sensitivity maps of the same shape as input x.
+ spatial_dims: is 2 for 2D data and is 3 for 3D data
+
+ Returns:
+ reduction of x to (B,1,H,W,2) for 2D data or (B,1,H,W,D,2) for 3D data.
+ """
+ img = ifftn_centered_t(kspace, spatial_dims=spatial_dims, is_complex=True) # inverse fourier transform
+ return complex_mul_t(img, complex_conj_t(sens_maps)).sum(dim=1, keepdim=True)
+
+
+def sensitivity_map_expand(img: Tensor, sens_maps: Tensor, spatial_dims: int = 2) -> Tensor:
+ """
+ Expands an image to its corresponding coil images based on the given sens_maps. Let's say there
+ are C coils. This function multiples image img with each coil sensitivity map in sens_maps and stacks
+ the resulting C coil images along the channel dimension which is reserved for coils.
+
+ Args:
+ img: 2D image (B,1,H,W,2) with the last dimension being 2 (for real/imaginary parts). 3D data will have
+ the shape (B,1,H,W,D,2).
+ sens_maps: Sensitivity maps for combining coil images. The shape is (B,C,H,W,2) for 2D data
+ or (B,C,H,W,D,2) for 3D data (C denotes the coil dimension).
+ spatial_dims: is 2 for 2D data and is 3 for 3D data
+
+ Returns:
+ Expansion of x to (B,C,H,W,2) for 2D data and (B,C,H,W,D,2) for 3D data. The output is transferred
+ to the frequency domain to yield coil measurements.
+ """
+ return fftn_centered_t(complex_mul_t(img, sens_maps), spatial_dims=spatial_dims, is_complex=True)
diff --git a/monai/apps/reconstruction/networks/nets/varnet.py b/monai/apps/reconstruction/networks/nets/varnet.py
new file mode 100644
index 00000000000..33b93b3d821
--- /dev/null
+++ b/monai/apps/reconstruction/networks/nets/varnet.py
@@ -0,0 +1,77 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import torch.nn as nn
+from torch import Tensor
+
+from monai.apps.reconstruction.complex_utils import complex_abs_t
+from monai.apps.reconstruction.mri_utils import root_sum_of_squares_t
+from monai.apps.reconstruction.networks.blocks.varnetblock import VarNetBlock
+from monai.networks.blocks.fft_utils_t import ifftn_centered_t
+
+
+class VariationalNetworkModel(nn.Module):
+ """
+ The end-to-end variational network (or simply e2e-VarNet) based on Sriram et. al., "End-to-end variational
+ networks for accelerated MRI reconstruction".
+ It comprises several cascades each consisting of refinement and data consistency steps. The network takes in
+ the under-sampled kspace and estimates the ground-truth reconstruction.
+
+ Modified and adopted from: https://github.com/facebookresearch/fastMRI
+
+ Args:
+ coil_sensitivity_model: A convolutional model for learning coil sensitivity maps. An example is
+ :py:class:`monai.apps.reconstruction.networks.nets.coil_sensitivity_model.CoilSensitivityModel`.
+ refinement_model: A convolutional network used in the refinement step of e2e-VarNet. An example
+ is :py:class:`monai.apps.reconstruction.networks.nets.complex_unet.ComplexUnet`.
+ num_cascades: Number of cascades. Each cascade is a
+ :py:class:`monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock` which consists of
+ refinement and data consistency steps.
+ spatial_dims: number of spatial dimensions.
+ """
+
+ def __init__(
+ self,
+ coil_sensitivity_model: nn.Module,
+ refinement_model: nn.Module,
+ num_cascades: int = 12,
+ spatial_dims: int = 2,
+ ):
+ super().__init__()
+ self.coil_sensitivity_model = coil_sensitivity_model
+ self.cascades = nn.ModuleList([VarNetBlock(copy.deepcopy(refinement_model)) for i in range(num_cascades)])
+ self.spatial_dims = spatial_dims
+
+ def forward(self, masked_kspace: Tensor, mask: Tensor) -> Tensor:
+ """
+ Args:
+ masked_kspace: The under-sampled kspace. It's a 2D kspace (B,C,H,W,2)
+ with the last dimension being 2 (for real/imaginary parts) and C denoting the
+ coil dimension. 3D data will have the shape (B,C,H,W,D,2).
+ mask: The under-sampling mask with shape (1,1,1,W,1) for 2D data or (1,1,1,1,D,1) for 3D data.
+
+ Returns:
+ The reconstructed image which is the root sum of squares (rss) of the absolute value
+ of the inverse fourier of the predicted kspace (note that rss combines coil images into one image).
+ """
+ sensitivity_maps = self.coil_sensitivity_model(masked_kspace, mask) # shape is similar to masked_kspace
+ kspace_pred = masked_kspace.clone()
+
+ for cascade in self.cascades:
+ kspace_pred = cascade(kspace_pred, masked_kspace, mask, sensitivity_maps)
+
+ output_image = root_sum_of_squares_t(
+ complex_abs_t(ifftn_centered_t(kspace_pred, spatial_dims=self.spatial_dims)),
+ spatial_dim=1, # 1 is for C which is the coil dimension
+ ) # shape is (B,H,W) for 2D and (B,H,W,D) for 3D data.
+ return output_image
diff --git a/monai/apps/reconstruction/transforms/array.py b/monai/apps/reconstruction/transforms/array.py
index 660eab396b5..ed58439d29d 100644
--- a/monai/apps/reconstruction/transforms/array.py
+++ b/monai/apps/reconstruction/transforms/array.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from abc import abstractmethod
from typing import Sequence
diff --git a/monai/apps/reconstruction/transforms/dictionary.py b/monai/apps/reconstruction/transforms/dictionary.py
index cf270b3a605..baa9bdb2ce8 100644
--- a/monai/apps/reconstruction/transforms/dictionary.py
+++ b/monai/apps/reconstruction/transforms/dictionary.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Dict, Hashable, Mapping, Optional, Sequence
import numpy as np
@@ -132,7 +131,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, T
d = dict(data)
for key in self.key_iterator(d):
d[key + "_masked"], d[key + "_masked_ifft"] = self.masker(d[key])
- d[FastMRIKeys.MASK] = self.masker.mask # type: ignore
+ d[FastMRIKeys.MASK] = self.masker.mask
+
return d # type: ignore
diff --git a/monai/apps/tcia/label_desc.py b/monai/apps/tcia/label_desc.py
index 582f83154a5..e3875e4095b 100644
--- a/monai/apps/tcia/label_desc.py
+++ b/monai/apps/tcia/label_desc.py
@@ -9,12 +9,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Dict
__all__ = ["TCIA_LABEL_DICT"]
-
TCIA_LABEL_DICT: Dict[str, Dict[str, int]] = {
"C4KC-KiTS": {"Kidney": 0, "Renal Tumor": 1},
"NSCLC-Radiomics": {
diff --git a/monai/auto3dseg/__init__.py b/monai/auto3dseg/__init__.py
new file mode 100644
index 00000000000..9d350260451
--- /dev/null
+++ b/monai/auto3dseg/__init__.py
@@ -0,0 +1,35 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .algo_gen import Algo, AlgoGen
+from .analyzer import (
+ Analyzer,
+ FgImageStats,
+ FgImageStatsSumm,
+ FilenameStats,
+ ImageStats,
+ ImageStatsSumm,
+ LabelStats,
+ LabelStatsSumm,
+)
+from .operations import Operations, SampleOperations, SummaryOperations
+from .seg_summarizer import SegSummarizer
+from .utils import (
+ algo_from_pickle,
+ algo_to_pickle,
+ concat_multikeys_to_dict,
+ concat_val_to_np,
+ datafold_read,
+ get_foreground_image,
+ get_foreground_label,
+ get_label_ccp,
+ verify_report_format,
+)
diff --git a/monai/auto3dseg/algo_gen.py b/monai/auto3dseg/algo_gen.py
new file mode 100644
index 00000000000..1f8f0e11ede
--- /dev/null
+++ b/monai/auto3dseg/algo_gen.py
@@ -0,0 +1,112 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from monai.transforms import Randomizable
+
+
+class Algo:
+ """
+ An algorithm in this context is loosely defined as a data processing pipeline consisting of multiple components
+ such as image preprocessing, followed by deep learning model training and evaluation.
+ """
+
+ def set_data_stats(self, *args, **kwargs):
+ """Provide dataset (and summaries) so that the model creation can depend on the input datasets."""
+ pass
+
+ def train(self, params: dict):
+ """
+ Read training/validation data and output a model.
+
+ Args:
+ params: key-value pairs of input parameters for the training pipeline.
+ """
+ pass
+
+ def predict(self, params: dict):
+ """
+ Read test data and output model predictions.
+
+ Args:
+ params: key-value pairs of input parameters for the predicting pipeline.
+ """
+ pass
+
+ def get_score(self, *args, **kwargs):
+ """Returns the model quality measurement based on training and validation datasets."""
+ pass
+
+ def get_output_path(self, *args, **kwargs):
+ """Returns the algo output paths for scripts location"""
+ pass
+
+
+class AlgoGen(Randomizable):
+ """
+ A data-driven algorithm generator. It optionally takes the following inputs:
+
+ - training dataset properties (such as data statistics from ``monai.auto3dseg.analyzer``),
+ - previous algorithm's scores measuring the model quality,
+ - computational budgets,
+
+ and generates ``Algo`` instances. The generated algos are to be trained with the training datasets::
+
+ scores
+ +------------------------+
+ | +---------+ |
+ +-----------+ +-->| | +-----+----+
+ | Dataset, | | AlgoGen |--->| Algo |
+ | summaries |------>| | +----------+
+ +-----+-----+ +---------+ ^
+ | |
+ +----------------------------------+
+
+ This class also maintains a history of previously generated Algo and their corresponding validation scores.
+ The Algo generation process may be stochastic (using ``Randomizable.R`` as the source random state).
+ """
+
+ def set_data_stats(self, *args, **kwargs): # type ignore
+ """Provide dataset summaries/properties so that the generator can be conditioned on the input datasets."""
+ pass
+
+ def set_budget(self, *args, **kwargs):
+ """Provide computational budget so that the generator outputs algorithms that requires reasonable resources."""
+ pass
+
+ def set_score(self, *args, **kwargs):
+ """Feedback from the previously generated algo, the score can be used for new Algo generations."""
+ pass
+
+ def get_data_stats(self, *args, **kwargs):
+ """Get current dataset summaries."""
+ pass
+
+ def get_budget(self, *args, **kwargs):
+ """Get the current computational budget."""
+ pass
+
+ def get_history(self, *args, **kwargs):
+ """Get the previously generated algo."""
+ pass
+
+ def generate(self):
+ """Generate new Algo -- based on data_stats, budget, and history of previous algo generations."""
+ pass
+
+ def run_algo(self, *args, **kwargs):
+ """
+ Launch the Algos. This is useful for light-weight Algos where there's no need to distribute the training jobs.
+
+ If the generated Algos require significant scheduling of parallel executions, a job scheduler/controller
+ implemented separately is preferred to run them. In this case the controller should also report back the
+ scores and the algo history, so that the future ``AlgoGen.generate`` can leverage the information.
+ """
+ pass
diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py
new file mode 100644
index 00000000000..19cd95b9066
--- /dev/null
+++ b/monai/auto3dseg/analyzer.py
@@ -0,0 +1,1009 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from abc import ABC, abstractmethod
+from copy import deepcopy
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+
+from monai.apps.utils import get_logger
+from monai.auto3dseg.operations import Operations, SampleOperations, SummaryOperations
+from monai.auto3dseg.utils import (
+ concat_multikeys_to_dict,
+ concat_val_to_np,
+ get_foreground_image,
+ get_foreground_label,
+ get_label_ccp,
+ verify_report_format,
+)
+from monai.bundle.config_parser import ConfigParser
+from monai.bundle.utils import ID_SEP_KEY
+from monai.data import MetaTensor, affine_to_spacing
+from monai.transforms.transform import MapTransform
+from monai.transforms.utils_pytorch_numpy_unification import sum, unique
+from monai.utils import convert_to_numpy
+from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys
+from monai.utils.misc import ImageMetaKey, label_union
+
+logger = get_logger(module_name=__name__)
+
+__all__ = [
+ "Analyzer",
+ "ImageStats",
+ "FgImageStats",
+ "LabelStats",
+ "ImageStatsSumm",
+ "FgImageStatsSumm",
+ "LabelStatsSumm",
+ "FilenameStats",
+ "ImageHistogram",
+ "ImageHistogramSumm",
+]
+
+
+class Analyzer(MapTransform, ABC):
+ """
+ The Analyzer component is a base class. Other classes inherit this class will provide a callable
+ with the same class name and produces one pre-formatted dictionary for the input data. The format
+ is pre-defined by the init function of the class that inherit this base class. Function operations
+ can also be registered before the runtime of the callable.
+
+ Args:
+ report_format: a dictionary that outlines the key structures of the report format.
+
+ """
+
+ def __init__(self, stats_name: str, report_format: dict) -> None:
+ super().__init__(None)
+ parser = ConfigParser(report_format, globals=False) # ConfigParser.globals not picklable
+ self.report_format = parser.get("")
+ self.stats_name = stats_name
+ self.ops = ConfigParser({}, globals=False)
+
+ def update_ops(self, key: str, op):
+ """
+ Register a statistical operation to the Analyzer and update the report_format.
+
+ Args:
+ key: value key in the report.
+ op: Operation sub-class object that represents statistical operations.
+
+ """
+ self.ops[key] = op
+ parser = ConfigParser(self.report_format)
+
+ if parser.get(key, "None") != "None":
+ parser[key] = op
+
+ self.report_format = parser.get("")
+
+ def update_ops_nested_label(self, nested_key: str, op):
+ """
+ Update operations for nested label format. Operation value in report_format will be resolved
+ to a dict with only keys.
+
+ Args:
+ nested_key: str that has format of 'key1#0#key2'.
+ op: Operation sub-class object that represents statistical operations.
+ """
+ keys = nested_key.split(ID_SEP_KEY)
+ if len(keys) != 3:
+ raise ValueError("Nested_key input format is wrong. Please ensure it is like key1#0#key2")
+ root: str
+ child_key: str
+ (root, _, child_key) = keys
+ if root not in self.ops:
+ self.ops[root] = [{}]
+ self.ops[root][0].update({child_key: None})
+
+ self.ops[nested_key] = op
+
+ parser = ConfigParser(self.report_format)
+ if parser.get(nested_key, "NA") != "NA":
+ parser[nested_key] = op
+
+ def get_report_format(self):
+ """
+ Get the report format by resolving the registered operations recursively.
+
+ Returns:
+ a dictionary with {keys: None} pairs.
+
+ """
+ self.resolve_format(self.report_format)
+ return self.report_format
+
+ @staticmethod
+ def unwrap_ops(func):
+ """
+ Unwrap a function value and generates the same set keys in a dict when the function is actually
+ called in runtime
+
+ Args:
+ func: Operation sub-class object that represents statistical operations. The func object
+ should have a `data` dictionary which stores the statistical operation information.
+ For some operations (ImageStats for example), it may also contain the data_addon
+ property, which is part of the update process.
+
+ Returns:
+ a dict with a set of keys.
+
+ """
+ ret = dict.fromkeys(list(func.data))
+ if hasattr(func, "data_addon"):
+ for key in func.data_addon:
+ ret.update({key: None})
+ return ret
+
+ def resolve_format(self, report: dict):
+ """
+ Resolve the format of the pre-defined report.
+
+ Args:
+ report: the dictionary to resolve. Values will be replaced in-place.
+
+ """
+ for k, v in report.items():
+ if isinstance(v, Operations):
+ report[k] = self.unwrap_ops(v)
+ elif isinstance(v, list) and len(v) > 0:
+ self.resolve_format(v[0])
+ else:
+ report[k] = v
+
+ @abstractmethod
+ def __call__(self, data: Any):
+ """Analyze the dict format dataset, return the summary report"""
+ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
+
+
+class ImageStats(Analyzer):
+ """
+ Analyzer to extract image stats properties for each case(image).
+
+ Args:
+ image_key: the key to find image data in the callable function input (data)
+
+ Examples:
+
+ .. code-block:: python
+
+ import numpy as np
+ from monai.auto3dseg import ImageStats
+ from monai.data import MetaTensor
+
+ input = {}
+ input['image'] = np.random.rand(1,30,30,30)
+ input['image'] = MetaTensor(np.random.rand(1,30,30,30)) # MetaTensor
+ analyzer = ImageStats(image_key="image")
+ print(analyzer(input)["image_stats"])
+
+ Notes:
+ if the image data is NumPy array, the spacing stats will be [1.0] * `ndims` of the array,
+ where the `ndims` is the lesser value between the image dimension and 3.
+
+ """
+
+ def __init__(self, image_key: str, stats_name: str = "image_stats") -> None:
+
+ if not isinstance(image_key, str):
+ raise ValueError("image_key input must be str")
+
+ self.image_key = image_key
+
+ report_format = {
+ ImageStatsKeys.SHAPE: None,
+ ImageStatsKeys.CHANNELS: None,
+ ImageStatsKeys.CROPPED_SHAPE: None,
+ ImageStatsKeys.SPACING: None,
+ ImageStatsKeys.INTENSITY: None,
+ }
+
+ super().__init__(stats_name, report_format)
+ self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
+
+ def __call__(self, data):
+ """
+ Callable to execute the pre-defined functions
+
+ Returns:
+ A dictionary. The dict has the key in self.report_format. The value of
+ ImageStatsKeys.INTENSITY is in a list format. Each element of the value list
+ has stats pre-defined by SampleOperations (max, min, ....).
+
+ Raises:
+ RuntimeError if the stats report generated is not consistent with the pre-
+ defined report_format.
+
+ Note:
+ The stats operation uses numpy and torch to compute max, min, and other
+ functions. If the input has nan/inf, the stats results will be nan/inf.
+
+ """
+ d = dict(data)
+ start = time.time()
+ restore_grad_state = torch.is_grad_enabled()
+ torch.set_grad_enabled(False)
+
+ ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
+ if "nda_croppeds" not in d:
+ nda_croppeds = [get_foreground_image(nda) for nda in ndas]
+
+ # perform calculation
+ report = deepcopy(self.get_report_format())
+
+ report[ImageStatsKeys.SHAPE] = [list(nda.shape) for nda in ndas]
+ report[ImageStatsKeys.CHANNELS] = len(ndas)
+ report[ImageStatsKeys.CROPPED_SHAPE] = [list(nda_c.shape) for nda_c in nda_croppeds]
+ report[ImageStatsKeys.SPACING] = (
+ affine_to_spacing(data[self.image_key].affine).tolist()
+ if isinstance(data[self.image_key], MetaTensor)
+ else [1.0] * min(3, data[self.image_key].ndim)
+ )
+ report[ImageStatsKeys.INTENSITY] = [
+ self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_c) for nda_c in nda_croppeds
+ ]
+
+ if not verify_report_format(report, self.get_report_format()):
+ raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+
+ d[self.stats_name] = report
+
+ torch.set_grad_enabled(restore_grad_state)
+ logger.debug(f"Get image stats spent {time.time()-start}")
+ return d
+
+
+class FgImageStats(Analyzer):
+ """
+ Analyzer to extract foreground label properties for each case(image and label).
+
+ Args:
+ image_key: the key to find image data in the callable function input (data)
+ label_key: the key to find label data in the callable function input (data)
+
+ Examples:
+
+ .. code-block:: python
+
+ import numpy as np
+ from monai.auto3dseg import FgImageStats
+
+ input = {}
+ input['image'] = np.random.rand(1,30,30,30)
+ input['label'] = np.ones([30,30,30])
+ analyzer = FgImageStats(image_key='image', label_key='label')
+ print(analyzer(input)["image_foreground_stats"])
+
+ """
+
+ def __init__(self, image_key: str, label_key: str, stats_name: str = "image_foreground_stats"):
+
+ self.image_key = image_key
+ self.label_key = label_key
+
+ report_format = {ImageStatsKeys.INTENSITY: None}
+
+ super().__init__(stats_name, report_format)
+ self.update_ops(ImageStatsKeys.INTENSITY, SampleOperations())
+
+ def __call__(self, data) -> dict:
+ """
+ Callable to execute the pre-defined functions
+
+ Returns:
+ A dictionary. The dict has the key in self.report_format and value
+ in a list format. Each element of the value list has stats pre-defined
+ by SampleOperations (max, min, ....).
+
+ Raises:
+ RuntimeError if the stats report generated is not consistent with the pre-
+ defined report_format.
+
+ Note:
+ The stats operation uses numpy and torch to compute max, min, and other
+ functions. If the input has nan/inf, the stats results will be nan/inf.
+ """
+
+ d = dict(data)
+ start = time.time()
+ restore_grad_state = torch.is_grad_enabled()
+ torch.set_grad_enabled(False)
+
+ ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
+ ndas_label = d[self.label_key] # (H,W,D)
+ nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
+
+ # perform calculation
+ report = deepcopy(self.get_report_format())
+
+ report[ImageStatsKeys.INTENSITY] = [
+ self.ops[ImageStatsKeys.INTENSITY].evaluate(nda_f) for nda_f in nda_foregrounds
+ ]
+
+ if not verify_report_format(report, self.get_report_format()):
+ raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+
+ d[self.stats_name] = report
+
+ torch.set_grad_enabled(restore_grad_state)
+ logger.debug(f"Get foreground image stats spent {time.time()-start}")
+ return d
+
+
+class LabelStats(Analyzer):
+ """
+ Analyzer to extract label stats properties for each case(image and label).
+
+ Args:
+ image_key: the key to find image data in the callable function input (data)
+ label_key: the key to find label data in the callable function input (data)
+ do_ccp: performs connected component analysis. Default is True.
+
+ Examples:
+
+ .. code-block:: python
+
+ import numpy as np
+ from monai.auto3dseg import LabelStats
+
+ input = {}
+ input['image'] = np.random.rand(1,30,30,30)
+ input['label'] = np.ones([30,30,30])
+ analyzer = LabelStats(image_key='image', label_key='label')
+ print(analyzer(input)["label_stats"])
+
+ """
+
+ def __init__(self, image_key: str, label_key: str, stats_name: str = "label_stats", do_ccp: Optional[bool] = True):
+
+ self.image_key = image_key
+ self.label_key = label_key
+ self.do_ccp = do_ccp
+
+ report_format: Dict[str, Any] = {
+ LabelStatsKeys.LABEL_UID: None,
+ LabelStatsKeys.IMAGE_INTST: None,
+ LabelStatsKeys.LABEL: [{LabelStatsKeys.PIXEL_PCT: None, LabelStatsKeys.IMAGE_INTST: None}],
+ }
+
+ if self.do_ccp:
+ report_format[LabelStatsKeys.LABEL][0].update(
+ {LabelStatsKeys.LABEL_SHAPE: None, LabelStatsKeys.LABEL_NCOMP: None}
+ )
+
+ super().__init__(stats_name, report_format)
+ self.update_ops(LabelStatsKeys.IMAGE_INTST, SampleOperations())
+
+ id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.IMAGE_INTST])
+ self.update_ops_nested_label(id_seq, SampleOperations())
+
+ def __call__(self, data):
+ """
+ Callable to execute the pre-defined functions.
+
+ Returns:
+ A dictionary. The dict has the key in self.report_format and value
+ in a list format. Each element of the value list has stats pre-defined
+ by SampleOperations (max, min, ....).
+
+ Examples:
+ output dict contains {
+ LabelStatsKeys.LABEL_UID:[0,1,3],
+ LabelStatsKeys.IMAGE_INTST: {...},
+ LabelStatsKeys.LABEL:[
+ {
+ LabelStatsKeys.PIXEL_PCT: 0.8,
+ LabelStatsKeys.IMAGE_INTST: {...},
+ LabelStatsKeys.LABEL_SHAPE: [...],
+ LabelStatsKeys.LABEL_NCOMP: 1
+ }
+ {
+ LabelStatsKeys.PIXEL_PCT: 0.1,
+ LabelStatsKeys.IMAGE_INTST: {...},
+ LabelStatsKeys.LABEL_SHAPE: [...],
+ LabelStatsKeys.LABEL_NCOMP: 1
+ }
+ {
+ LabelStatsKeys.PIXEL_PCT: 0.1,
+ LabelStatsKeys.IMAGE_INTST: {...},
+ LabelStatsKeys.LABEL_SHAPE: [...],
+ LabelStatsKeys.LABEL_NCOMP: 1
+ }
+ ]
+ }
+
+ Raises:
+ RuntimeError if the stats report generated is not consistent with the pre-
+ defined report_format.
+
+ Notes:
+ The label class_ID of the dictionary in LabelStatsKeys.LABEL IS NOT the
+ index. Instead, the class_ID is the LabelStatsKeys.LABEL_UID with the same
+ index. For instance, the last dict in LabelStatsKeys.LABEL in the Examples
+ is 3, which is the last element under LabelStatsKeys.LABEL_UID.
+
+ The stats operation uses numpy and torch to compute max, min, and other
+ functions. If the input has nan/inf, the stats results will be nan/inf.
+ """
+ d = dict(data)
+ start = time.time()
+ if isinstance(d[self.image_key], (torch.Tensor, MetaTensor)) and d[self.image_key].device.type == "cuda":
+ using_cuda = True
+ else:
+ using_cuda = False
+ restore_grad_state = torch.is_grad_enabled()
+ torch.set_grad_enabled(False)
+
+ ndas = [d[self.image_key][i] for i in range(d[self.image_key].shape[0])]
+ ndas_label = d[self.label_key] # (H,W,D)
+ nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
+
+ unique_label = unique(ndas_label)
+ if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
+ unique_label = unique_label.data.cpu().numpy()
+
+ unique_label = unique_label.astype(np.int8).tolist()
+
+ label_substats = [] # each element is one label
+ pixel_sum = 0
+ pixel_arr = []
+ for index in unique_label:
+ start_label = time.time()
+ label_dict: Dict[str, Any] = {}
+ mask_index = ndas_label == index
+
+ nda_masks = [nda[mask_index] for nda in ndas]
+ label_dict[LabelStatsKeys.IMAGE_INTST] = [
+ self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_m) for nda_m in nda_masks
+ ]
+
+ pixel_count = sum(mask_index)
+ pixel_arr.append(pixel_count)
+ pixel_sum += pixel_count
+ if self.do_ccp: # apply connected component
+ if using_cuda:
+ # The back end of get_label_ccp is CuPy
+ # which is unable to automatically release CUDA GPU memory held by PyTorch
+ del nda_masks
+ torch.cuda.empty_cache()
+ shape_list, ncomponents = get_label_ccp(mask_index)
+ label_dict[LabelStatsKeys.LABEL_SHAPE] = shape_list
+ label_dict[LabelStatsKeys.LABEL_NCOMP] = ncomponents
+
+ label_substats.append(label_dict)
+ logger.debug(f" label {index} stats takes {time.time() - start_label}")
+
+ for i, _ in enumerate(unique_label):
+ label_substats[i].update({LabelStatsKeys.PIXEL_PCT: float(pixel_arr[i] / pixel_sum)})
+
+ report = deepcopy(self.get_report_format())
+ report[LabelStatsKeys.LABEL_UID] = unique_label
+ report[LabelStatsKeys.IMAGE_INTST] = [
+ self.ops[LabelStatsKeys.IMAGE_INTST].evaluate(nda_f) for nda_f in nda_foregrounds
+ ]
+ report[LabelStatsKeys.LABEL] = label_substats
+
+ if not verify_report_format(report, self.get_report_format()):
+ raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+
+ d[self.stats_name] = report
+
+ torch.set_grad_enabled(restore_grad_state)
+ logger.debug(f"Get label stats spent {time.time()-start}")
+ return d
+
+
+class ImageStatsSumm(Analyzer):
+ """
+ This summary analyzer processes the values of specific key `stats_name` in a list of dict.
+ Typically, the list of dict is the output of case analyzer under the same prefix
+ (ImageStats).
+
+ Args:
+ stats_name: the key of the to-process value in the dict.
+ average: whether to average the statistical value across different image modalities.
+
+ """
+
+ def __init__(self, stats_name: str = "image_stats", average: Optional[bool] = True):
+ self.summary_average = average
+ report_format = {
+ ImageStatsKeys.SHAPE: None,
+ ImageStatsKeys.CHANNELS: None,
+ ImageStatsKeys.CROPPED_SHAPE: None,
+ ImageStatsKeys.SPACING: None,
+ ImageStatsKeys.INTENSITY: None,
+ }
+ super().__init__(stats_name, report_format)
+
+ self.update_ops(ImageStatsKeys.SHAPE, SampleOperations())
+ self.update_ops(ImageStatsKeys.CHANNELS, SampleOperations())
+ self.update_ops(ImageStatsKeys.CROPPED_SHAPE, SampleOperations())
+ self.update_ops(ImageStatsKeys.SPACING, SampleOperations())
+ self.update_ops(ImageStatsKeys.INTENSITY, SummaryOperations())
+
+ def __call__(self, data: List[Dict]):
+ """
+ Callable to execute the pre-defined functions
+
+ Returns:
+ A dictionary. The dict has the key in self.report_format and value
+ in a list format. Each element of the value list has stats pre-defined
+ by SampleOperations (max, min, ....).
+
+ Raises:
+ RuntimeError if the stats report generated is not consistent with the pre-
+ defined report_format.
+
+ Examples:
+ output dict contains a dictionary for all of the following keys{
+ ImageStatsKeys.SHAPE:{...}
+ ImageStatsKeys.CHANNELS: {...},
+ ImageStatsKeys.CROPPED_SHAPE: {...},
+ ImageStatsKeys.SPACING: {...},
+ ImageStatsKeys.INTENSITY: {...},
+ }
+
+ Notes:
+ The stats operation uses numpy and torch to compute max, min, and other
+ functions. If the input has nan/inf, the stats results will be nan/inf.
+ """
+ if not isinstance(data, list):
+ return ValueError(f"Callable {self.__class__} requires list inputs")
+
+ if len(data) == 0:
+ return ValueError(f"Callable {self.__class__} input list is empty")
+
+ if self.stats_name not in data[0]:
+ return KeyError(f"{self.stats_name} is not in input data")
+
+ report = deepcopy(self.get_report_format())
+
+ for k in [ImageStatsKeys.SHAPE, ImageStatsKeys.CHANNELS, ImageStatsKeys.CROPPED_SHAPE, ImageStatsKeys.SPACING]:
+ v_np = concat_val_to_np(data, [self.stats_name, k])
+ report[k] = self.ops[k].evaluate(v_np, dim=(0, 1) if v_np.ndim > 2 and self.summary_average else 0)
+
+ intst_str = ImageStatsKeys.INTENSITY
+ op_keys = report[intst_str].keys() # template, max/min/...
+ intst_dict = concat_multikeys_to_dict(data, [self.stats_name, intst_str], op_keys)
+ report[intst_str] = self.ops[intst_str].evaluate(intst_dict, dim=None if self.summary_average else 0)
+
+ if not verify_report_format(report, self.get_report_format()):
+ raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+
+ return report
+
+
+class FgImageStatsSumm(Analyzer):
+ """
+ This summary analyzer processes the values of specific key `stats_name` in a list of
+ dict. Typically, the list of dict is the output of case analyzer under the similar name
+ (FgImageStats).
+
+ Args:
+ stats_name: the key of the to-process value in the dict.
+ average: whether to average the statistical value across different image modalities.
+
+ """
+
+ def __init__(self, stats_name: str = "image_foreground_stats", average: Optional[bool] = True):
+ self.summary_average = average
+
+ report_format = {ImageStatsKeys.INTENSITY: None}
+ super().__init__(stats_name, report_format)
+ self.update_ops(ImageStatsKeys.INTENSITY, SummaryOperations())
+
+ def __call__(self, data: List[Dict]):
+ """
+ Callable to execute the pre-defined functions.
+
+ Returns:
+ A dictionary. The dict has the key in self.report_format and value
+ in a list format. Each element of the value list has stats pre-defined
+ by SampleOperations (max, min, ....) and SummaryOperation (max of the
+ max, mean of the mean, etc).
+
+ Raises:
+ RuntimeError if the stats report generated is not consistent with the pre-
+ defined report_format.
+
+ Examples:
+ output dict contains a dictionary for all of the following keys{
+ ImageStatsKeys.INTENSITY: {...},
+ }
+
+ Notes:
+ The stats operation uses numpy and torch to compute max, min, and other
+ functions. If the input has nan/inf, the stats results will be nan/inf.
+ """
+ if not isinstance(data, list):
+ return ValueError(f"Callable {self.__class__} requires list inputs")
+
+ if len(data) == 0:
+ return ValueError(f"Callable {self.__class__} input list is empty")
+
+ if self.stats_name not in data[0]:
+ return KeyError(f"{self.stats_name} is not in input data.")
+
+ report = deepcopy(self.get_report_format())
+ intst_str = ImageStatsKeys.INTENSITY
+ op_keys = report[intst_str].keys() # template, max/min/...
+ intst_dict = concat_multikeys_to_dict(data, [self.stats_name, intst_str], op_keys)
+
+ report[intst_str] = self.ops[intst_str].evaluate(intst_dict, dim=None if self.summary_average else 0)
+
+ if not verify_report_format(report, self.get_report_format()):
+ raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+
+ return report
+
+
+class LabelStatsSumm(Analyzer):
+ """
+ This summary analyzer processes the values of specific key `stats_name` in a list of
+ dict. Typically, the list of dict is the output of case analyzer under the similar name
+ (LabelStats).
+
+ Args:
+ stats_name: the key of the to-process value in the dict.
+ average: whether to average the statistical value across different image modalities.
+
+ """
+
+ def __init__(self, stats_name: str = "label_stats", average: Optional[bool] = True, do_ccp: Optional[bool] = True):
+ self.summary_average = average
+ self.do_ccp = do_ccp
+
+ report_format: Dict[str, Any] = {
+ LabelStatsKeys.LABEL_UID: None,
+ LabelStatsKeys.IMAGE_INTST: None,
+ LabelStatsKeys.LABEL: [{LabelStatsKeys.PIXEL_PCT: None, LabelStatsKeys.IMAGE_INTST: None}],
+ }
+ if self.do_ccp:
+ report_format[LabelStatsKeys.LABEL][0].update(
+ {LabelStatsKeys.LABEL_SHAPE: None, LabelStatsKeys.LABEL_NCOMP: None}
+ )
+
+ super().__init__(stats_name, report_format)
+ self.update_ops(LabelStatsKeys.IMAGE_INTST, SummaryOperations())
+
+ # label-0-'pixel percentage'
+ id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.PIXEL_PCT])
+ self.update_ops_nested_label(id_seq, SampleOperations())
+ # label-0-'image intensity'
+ id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.IMAGE_INTST])
+ self.update_ops_nested_label(id_seq, SummaryOperations())
+ # label-0-shape
+ id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.LABEL_SHAPE])
+ self.update_ops_nested_label(id_seq, SampleOperations())
+ # label-0-ncomponents
+ id_seq = ID_SEP_KEY.join([LabelStatsKeys.LABEL, "0", LabelStatsKeys.LABEL_NCOMP])
+ self.update_ops_nested_label(id_seq, SampleOperations())
+
+ def __call__(self, data: List[Dict]):
+ """
+ Callable to execute the pre-defined functions
+
+ Returns:
+ A dictionary. The dict has the key in self.report_format and value
+ in a list format. Each element of the value list has stats pre-defined
+ by SampleOperations (max, min, ....) and SummaryOperation (max of the
+ max, mean of the mean, etc).
+
+ Raises:
+ RuntimeError if the stats report generated is not consistent with the pre-
+ defined report_format.
+
+ Notes:
+ The stats operation uses numpy and torch to compute max, min, and other
+ functions. If the input has nan/inf, the stats results will be nan/inf.
+ """
+ if not isinstance(data, list):
+ return ValueError(f"Callable {self.__class__} requires list inputs")
+
+ if len(data) == 0:
+ return ValueError(f"Callable {self.__class__} input list is empty")
+
+ if self.stats_name not in data[0]:
+ return KeyError(f"{self.stats_name} is not in input data")
+
+ report = deepcopy(self.get_report_format())
+ # unique class ID
+ uid_np = concat_val_to_np(data, [self.stats_name, LabelStatsKeys.LABEL_UID], axis=None, ragged=True)
+ unique_label = label_union(uid_np)
+ report[LabelStatsKeys.LABEL_UID] = unique_label
+
+ # image intensity
+ intst_str = LabelStatsKeys.IMAGE_INTST
+ op_keys = report[intst_str].keys() # template, max/min/...
+ intst_dict = concat_multikeys_to_dict(data, [self.stats_name, intst_str], op_keys)
+ report[intst_str] = self.ops[intst_str].evaluate(intst_dict, dim=None if self.summary_average else 0)
+
+ detailed_label_list = []
+ # iterate through each label
+ label_str = LabelStatsKeys.LABEL
+ for label_id in unique_label:
+ stats = {}
+
+ pct_str = LabelStatsKeys.PIXEL_PCT
+ pct_fixed_keys = [self.stats_name, label_str, label_id, pct_str]
+ pct_np = concat_val_to_np(data, pct_fixed_keys, allow_missing=True)
+ stats[pct_str] = self.ops[label_str][0][pct_str].evaluate(
+ pct_np, dim=(0, 1) if pct_np.ndim > 2 and self.summary_average else 0
+ )
+
+ if self.do_ccp:
+ ncomp_str = LabelStatsKeys.LABEL_NCOMP
+ ncomp_fixed_keys = [self.stats_name, LabelStatsKeys.LABEL, label_id, ncomp_str]
+ ncomp_np = concat_val_to_np(data, ncomp_fixed_keys, allow_missing=True)
+ stats[ncomp_str] = self.ops[label_str][0][ncomp_str].evaluate(
+ ncomp_np, dim=(0, 1) if ncomp_np.ndim > 2 and self.summary_average else 0
+ )
+
+ shape_str = LabelStatsKeys.LABEL_SHAPE
+ shape_fixed_keys = [self.stats_name, label_str, label_id, LabelStatsKeys.LABEL_SHAPE]
+ shape_np = concat_val_to_np(data, shape_fixed_keys, ragged=True, allow_missing=True)
+ stats[shape_str] = self.ops[label_str][0][shape_str].evaluate(
+ shape_np, dim=(0, 1) if shape_np.ndim > 2 and self.summary_average else 0
+ )
+ # label shape is a 3-element value, but the number of labels in each image
+ # can vary from 0 to N. So the value in a list format is "ragged"
+
+ intst_str = LabelStatsKeys.IMAGE_INTST
+ intst_fixed_keys = [self.stats_name, label_str, label_id, intst_str]
+ op_keys = report[label_str][0][intst_str].keys()
+ intst_dict = concat_multikeys_to_dict(data, intst_fixed_keys, op_keys, allow_missing=True)
+ stats[intst_str] = self.ops[label_str][0][intst_str].evaluate(
+ intst_dict, dim=None if self.summary_average else 0
+ )
+
+ detailed_label_list.append(stats)
+
+ report[LabelStatsKeys.LABEL] = detailed_label_list
+
+ if not verify_report_format(report, self.get_report_format()):
+ raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+
+ return report
+
+
+class FilenameStats(Analyzer):
+ """
+ This class finds the file path for the loaded image/label and writes the info
+ into the data pipeline as a monai transforms.
+
+ Args:
+ key: the key to fetch the filename (for example, "image", "label").
+ stats_name: the key to store the filename in the output stats report.
+
+ """
+
+ def __init__(self, key: str, stats_name: str) -> None:
+ self.key = key
+ super().__init__(stats_name, {})
+
+ def __call__(self, data):
+ d = dict(data)
+
+ if self.key: # when there is no (label) file, key can be None
+ if self.key not in d: # check whether image/label is in the data
+ raise ValueError(f"Data with key {self.key} is missing.")
+ if not isinstance(d[self.key], MetaTensor):
+ raise ValueError(f"Value type of {self.key} is not MetaTensor.")
+ if ImageMetaKey.FILENAME_OR_OBJ not in d[self.key].meta:
+ raise ValueError(f"{ImageMetaKey.FILENAME_OR_OBJ} not found in MetaTensor {d[self.key]}.")
+ d[self.stats_name] = d[self.key].meta[ImageMetaKey.FILENAME_OR_OBJ]
+ else:
+ d[self.stats_name] = "None"
+
+ return d
+
+
+class ImageHistogram(Analyzer):
+ """
+ Analyzer to compute intensity histogram.
+
+ Args:
+ image_key: the key to find image data in the callable function input (data)
+ hist_bins: list of positive integers (one for each channel) for setting the number of bins used to
+ compute the histogram. Defaults to [100].
+ hist_range: list of lists of two floats (one for each channel) setting the intensity range to
+ compute the histogram. Defaults to [-500, 500].
+
+ Examples:
+
+ .. code-block:: python
+
+ import numpy as np
+ from monai.auto3dseg.analyzer import ImageHistogram
+
+ input = {}
+ input['image'] = np.random.rand(1,30,30,30)
+ input['label'] = np.ones([30,30,30])
+ analyzer = ImageHistogram(image_key='image')
+ print(analyzer(input))
+
+ """
+
+ def __init__(
+ self,
+ image_key: str,
+ stats_name: str = DataStatsKeys.IMAGE_HISTOGRAM,
+ hist_bins: Optional[list] = None,
+ hist_range: Optional[list] = None,
+ ):
+
+ self.image_key = image_key
+
+ # set defaults
+ self.hist_bins: list = [100] if hist_bins is None else hist_bins
+ self.hist_range: list = [-500, 500] if hist_range is None else hist_range
+
+ report_format = {"counts": None, "bin_edges": None}
+
+ super().__init__(stats_name, report_format)
+ self.update_ops(ImageStatsKeys.HISTOGRAM, SampleOperations())
+
+ # check histogram configurations for each channel in list
+ if not isinstance(self.hist_bins, list):
+ self.hist_bins = [self.hist_bins]
+ if not all(isinstance(hr, list) for hr in self.hist_range):
+ self.hist_range = [self.hist_range]
+ if len(self.hist_bins) != len(self.hist_range):
+ raise ValueError(
+ f"Number of histogram bins ({len(self.hist_bins)}) and "
+ f"histogram ranges ({len(self.hist_range)}) need to be the same!"
+ )
+ for i, hist_params in enumerate(zip(self.hist_bins, self.hist_range)):
+ _hist_bins, _hist_range = hist_params
+ if not isinstance(_hist_bins, int) or _hist_bins < 0:
+ raise ValueError(f"Expected {i+1}. hist_bins value to be positive integer but got {_hist_bins}")
+ if not isinstance(_hist_range, list) or len(_hist_range) != 2:
+ raise ValueError(f"Expected {i+1}. hist_range values to be list of length 2 but received {_hist_range}")
+
+ def __call__(self, data) -> dict:
+ """
+ Callable to execute the pre-defined functions
+
+ Returns:
+ A dictionary. The dict has the key in self.report_format and value
+
+ Raises:
+ RuntimeError if the stats report generated is not consistent with the pre-
+ defined report_format.
+
+ Note:
+ The stats operation uses numpy and torch to compute max, min, and other
+ functions. If the input has nan/inf, the stats results will be nan/inf.
+ """
+
+ d = dict(data)
+
+ ndas = convert_to_numpy(d[self.image_key], wrap_sequence=True) # (1,H,W,D) or (C,H,W,D)
+ nr_channels = np.shape(ndas)[0]
+
+ # adjust histogram params to match channels
+ if len(self.hist_bins) == 1:
+ self.hist_bins = nr_channels * self.hist_bins
+ if len(self.hist_bins) != nr_channels:
+ raise ValueError(
+ f"There is a mismatch between the number of channels ({nr_channels}) "
+ f"and number histogram bins ({len(self.hist_bins)})."
+ )
+ if len(self.hist_range) == 1:
+ self.hist_range = nr_channels * self.hist_range
+ if len(self.hist_range) != nr_channels:
+ raise ValueError(
+ f"There is a mismatch between the number of channels ({nr_channels}) "
+ f"and histogram ranges ({len(self.hist_range)})."
+ )
+
+ # perform calculation
+ reports = []
+ for channel in range(nr_channels):
+ counts, bin_edges = np.histogram(
+ ndas[channel, ...],
+ bins=self.hist_bins[channel],
+ range=(self.hist_range[channel][0], self.hist_range[channel][1]),
+ )
+ _report = {"counts": counts.tolist(), "bin_edges": bin_edges.tolist()}
+ if not verify_report_format(_report, self.get_report_format()):
+ raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+ reports.append(_report)
+
+ d[self.stats_name] = reports
+ return d
+
+
+class ImageHistogramSumm(Analyzer):
+ """
+ This summary analyzer processes the values of specific key `stats_name` in a list of dict.
+ Typically, the list of dict is the output of case analyzer under the same prefix
+ (ImageHistogram).
+
+ Args:
+ stats_name: the key of the to-process value in the dict.
+ average: whether to average the statistical value across different image modalities.
+
+ """
+
+ def __init__(self, stats_name: str = DataStatsKeys.IMAGE_HISTOGRAM, average: Optional[bool] = True):
+ self.summary_average = average
+ report_format = {ImageStatsKeys.HISTOGRAM: None}
+ super().__init__(stats_name, report_format)
+
+ self.update_ops(ImageStatsKeys.HISTOGRAM, SummaryOperations())
+
+ def __call__(self, data: List[Dict]):
+ """
+ Callable to execute the pre-defined functions
+
+ Returns:
+ A dictionary. The dict has the key in self.report_format and value
+ in a list format. Each element of the value list has stats pre-defined
+ by SampleOperations (max, min, ....).
+
+ Raises:
+ RuntimeError if the stats report generated is not consistent with the pre-
+ defined report_format.
+
+ Examples:
+ output dict contains a dictionary for all of the following keys{
+ ImageStatsKeys.SHAPE:{...}
+ ImageStatsKeys.CHANNELS: {...},
+ ImageStatsKeys.CROPPED_SHAPE: {...},
+ ImageStatsKeys.SPACING: {...},
+ ImageStatsKeys.INTENSITY: {...},
+ }
+
+ Notes:
+ The stats operation uses numpy and torch to compute max, min, and other
+ functions. If the input has nan/inf, the stats results will be nan/inf.
+ """
+ if not isinstance(data, list):
+ return ValueError(f"Callable {self.__class__} requires list inputs")
+
+ if len(data) == 0:
+ return ValueError(f"Callable {self.__class__} input list is empty")
+
+ if self.stats_name not in data[0]:
+ return KeyError(f"{self.stats_name} is not in input data")
+
+ summ_histogram: Dict = {}
+
+ for d in data:
+ if not summ_histogram:
+ summ_histogram = d[DataStatsKeys.IMAGE_HISTOGRAM]
+ # convert to numpy for computing total histogram
+ for k in range(len(summ_histogram)):
+ summ_histogram[k]["counts"] = np.array(summ_histogram[k]["counts"])
+ else:
+ for k in range(len(summ_histogram)):
+ summ_histogram[k]["counts"] += np.array(d[DataStatsKeys.IMAGE_HISTOGRAM][k]["counts"])
+ if np.all(summ_histogram[k]["bin_edges"] != d[DataStatsKeys.IMAGE_HISTOGRAM][k]["bin_edges"]):
+ raise ValueError(
+ f"bin edges are not consistent! {summ_histogram[k]['bin_edges']} vs. "
+ f"{d[DataStatsKeys.IMAGE_HISTOGRAM][k]['bin_edges']}"
+ )
+
+ # convert back to list
+ for k in range(len(summ_histogram)):
+ summ_histogram[k]["counts"] = summ_histogram[k]["counts"].tolist()
+
+ report = {ImageStatsKeys.HISTOGRAM: summ_histogram}
+ if not verify_report_format(report, self.get_report_format()):
+ raise RuntimeError(f"report generated by {self.__class__} differs from the report format.")
+
+ return report
diff --git a/monai/auto3dseg/operations.py b/monai/auto3dseg/operations.py
new file mode 100644
index 00000000000..45294549ef3
--- /dev/null
+++ b/monai/auto3dseg/operations.py
@@ -0,0 +1,150 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import UserDict
+from functools import partial
+from typing import Any
+
+from monai.transforms.utils_pytorch_numpy_unification import max, mean, median, min, percentile, std
+
+__all__ = ["Operations", "SampleOperations", "SummaryOperations"]
+
+
+class Operations(UserDict):
+ """
+ Base class of operation interface
+ """
+
+ def evaluate(self, data: Any, **kwargs) -> dict:
+ """
+ For key-value pairs in the self.data, if the value is a callable,
+ then this function will apply the callable to the input data.
+ The result will be written under the same key under the output dict.
+
+ Args:
+ data: input data.
+
+ Returns:
+ a dictionary which has same keys as the self.data if the value
+ is callable.
+ """
+ return {k: v(data, **kwargs) for k, v in self.data.items() if callable(v)}
+
+
+class SampleOperations(Operations):
+ """
+ Apply statistical operation to a sample (image/ndarray/tensor).
+
+ Notes:
+ Percentile operation uses a partial function that embeds different kwargs (q).
+ In order to print the result nicely, data_addon is added to map the numbers
+ generated by percentile to different keys ("percentile_00_5" for example).
+ Annotation of the postfix means the percentage for percentile computation.
+ For example, _00_5 means 0.5% and _99_5 means 99.5%.
+
+ Example:
+
+ .. code-block:: python
+
+ # use the existing operations
+ import numpy as np
+ op = SampleOperations()
+ data_np = np.random.rand(10, 10).astype(np.float64)
+ print(op.evaluate(data_np))
+
+ # add a new operation
+ op.update({"sum": np.sum})
+ print(op.evaluate(data_np))
+ """
+
+ def __init__(self) -> None:
+ self.data = {
+ "max": max,
+ "mean": mean,
+ "median": median,
+ "min": min,
+ "stdev": std,
+ "percentile": partial(percentile, q=[0.5, 10, 90, 99.5]),
+ }
+ self.data_addon = {
+ "percentile_00_5": ("percentile", 0),
+ "percentile_10_0": ("percentile", 1),
+ "percentile_90_0": ("percentile", 2),
+ "percentile_99_5": ("percentile", 3),
+ }
+
+ def evaluate(self, data: Any, **kwargs) -> dict:
+ """
+ Applies the callables to the data, and convert the
+ numerics to list or Python numeric types (int/float).
+
+ Args:
+ data: input data
+ """
+ ret = super().evaluate(data, **kwargs)
+ for k, v in self.data_addon.items():
+ cache = v[0]
+ idx = v[1]
+ if isinstance(v, tuple) and cache in ret:
+ ret.update({k: ret[cache][idx]})
+
+ for k, v in ret.items():
+ ret[k] = v.tolist() # type: ignore
+ return ret
+
+
+class SummaryOperations(Operations):
+ """
+ Apply statistical operation to summarize a dict. The key-value looks like: {"max", "min"
+ ,"mean", ....}. The value may contain multiple values in a list format. Then this operation
+ will apply the operation to the list. Typically, the dict is generated by multiple
+ `SampleOperation` and `concat_multikeys_to_dict` functions.
+
+ Examples:
+
+ .. code-block:: python
+
+ import numpy as np
+ data = {
+ "min": np.random.rand(4),
+ "max": np.random.rand(4),
+ "mean": np.random.rand(4),
+ "sum": np.random.rand(4),
+ }
+ op = SummaryOperations()
+ print(op.evaluate(data)) # "sum" is not registered yet, so it won't contain "sum"
+
+ op.update({"sum", np.sum})
+ print(op.evaluate(data)) # output has "sum"
+ """
+
+ def __init__(self) -> None:
+ self.data = {
+ "max": max,
+ "mean": mean,
+ "median": mean,
+ "min": min,
+ "stdev": mean,
+ "percentile_00_5": mean,
+ "percentile_10_0": mean,
+ "percentile_90_0": mean,
+ "percentile_99_5": mean,
+ }
+
+ def evaluate(self, data: Any, **kwargs) -> dict:
+ """
+ Applies the callables to the data, and convert the numerics to list or Python
+ numeric types (int/float).
+
+ Args:
+ data: input data
+ """
+ return {k: v(data[k], **kwargs).tolist() for k, v in self.data.items() if (callable(v) and k in data)}
diff --git a/monai/auto3dseg/seg_summarizer.py b/monai/auto3dseg/seg_summarizer.py
new file mode 100644
index 00000000000..e158068d4e6
--- /dev/null
+++ b/monai/auto3dseg/seg_summarizer.py
@@ -0,0 +1,210 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional
+
+from monai.auto3dseg.analyzer import (
+ FgImageStats,
+ FgImageStatsSumm,
+ FilenameStats,
+ ImageHistogram,
+ ImageHistogramSumm,
+ ImageStats,
+ ImageStatsSumm,
+ LabelStats,
+ LabelStatsSumm,
+)
+from monai.transforms import Compose
+from monai.utils.enums import DataStatsKeys
+
+__all__ = ["SegSummarizer"]
+
+
+class SegSummarizer(Compose):
+ """
+ SegSummarizer serializes the operations for data analysis in Auto3Dseg pipeline. It loads
+ two types of analyzer functions and execute differently. The first type of analyzer is
+ CaseAnalyzer which is similar to traditional monai transforms. It can be composed with other
+ transforms to process the data dict which has image/label keys. The second type of analyzer
+ is SummaryAnalyzer which works only on a list of dictionary. Each dictionary is the output
+ of the case analyzers on a single dataset.
+
+ Args:
+ image_key: a string that user specify for the image. The DataAnalyzer will look it up in the
+ datalist to locate the image files of the dataset.
+ label_key: a string that user specify for the label. The DataAnalyzer will look it up in the
+ datalist to locate the label files of the dataset. If label_key is None, the DataAnalyzer
+ will skip looking for labels and all label-related operations.
+ do_ccp: apply the connected component algorithm to process the labels/images.
+ hist_bins: list of positive integers (one for each channel) for setting the number of bins used to
+ compute the histogram. Defaults to [100].
+ hist_range: list of lists of two floats (one for each channel) setting the intensity range to
+ compute the histogram. Defaults to [-500, 500].
+ histogram_only: whether to only compute histograms. Defaults to False.
+
+ Examples:
+ .. code-block:: python
+
+ # imports
+
+ summarizer = SegSummarizer("image", "label")
+ transform_list = [
+ LoadImaged(keys=keys),
+ EnsureChannelFirstd(keys=keys), # this creates label to be (1,H,W,D)
+ ToDeviced(keys=keys, device=device, non_blocking=True),
+ Orientationd(keys=keys, axcodes="RAS"),
+ EnsureTyped(keys=keys, data_type="tensor"),
+ Lambdad(keys="label", func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),
+ SqueezeDimd(keys=["label"], dim=0),
+ summarizer,
+ ]
+ ...
+ # skip some steps to set up data loader
+ dataset = data.DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ transform = Compose(transform_list)
+ stats = []
+ for batch_data in dataset:
+ d = transform(batch_data[0])
+ stats.append(d)
+ report = summarizer.summarize(stats)
+ """
+
+ def __init__(
+ self,
+ image_key: str,
+ label_key: str,
+ average=True,
+ do_ccp: bool = True,
+ hist_bins: Optional[list] = None,
+ hist_range: Optional[list] = None,
+ histogram_only: bool = False,
+ ) -> None:
+
+ self.image_key = image_key
+ self.label_key = label_key
+ # set defaults
+ self.hist_bins: list = [100] if hist_bins is None else hist_bins
+ self.hist_range: list = [-500, 500] if hist_range is None else hist_range
+ self.histogram_only = histogram_only
+
+ self.summary_analyzers: List[Any] = []
+ super().__init__()
+
+ if not self.histogram_only:
+ self.add_analyzer(FilenameStats(image_key, DataStatsKeys.BY_CASE_IMAGE_PATH), None)
+ self.add_analyzer(FilenameStats(label_key, DataStatsKeys.BY_CASE_LABEL_PATH), None)
+ self.add_analyzer(ImageStats(image_key), ImageStatsSumm(average=average))
+
+ if label_key is None:
+ return
+
+ self.add_analyzer(FgImageStats(image_key, label_key), FgImageStatsSumm(average=average))
+
+ self.add_analyzer(
+ LabelStats(image_key, label_key, do_ccp=do_ccp), LabelStatsSumm(average=average, do_ccp=do_ccp)
+ )
+
+ # compute histograms
+ if self.hist_bins != 0: # type: ignore
+ self.add_analyzer(
+ ImageHistogram(image_key=image_key, hist_bins=hist_bins, hist_range=hist_range), ImageHistogramSumm()
+ )
+
+ def add_analyzer(self, case_analyzer, summary_analyzer) -> None:
+ """
+ Add new analyzers to the engine so that the callable and summarize functions will
+ utilize the new analyzers for stats computations.
+
+ Args:
+ case_analyzer: analyzer that works on each data.
+ summary_analyzer: analyzer that works on list of stats dict (output from case_analyzers).
+
+ Examples:
+
+ .. code-block:: python
+
+ from monai.auto3dseg import Analyzer
+ from monai.auto3dseg.utils import concat_val_to_np
+ from monai.auto3dseg.analyzer_engine import SegSummarizer
+
+ class UserAnalyzer(Analyzer):
+ def __init__(self, image_key="image", stats_name="user_stats"):
+ self.image_key = image_key
+ report_format = {"ndims": None}
+ super().__init__(stats_name, report_format)
+
+ def __call__(self, data):
+ d = dict(data)
+ report = deepcopy(self.get_report_format())
+ report["ndims"] = d[self.image_key].ndim
+ d[self.stats_name] = report
+ return d
+
+ class UserSummaryAnalyzer(Analyzer):
+ def __init__(stats_name="user_stats"):
+ report_format = {"ndims": None}
+ super().__init__(stats_name, report_format)
+ self.update_ops("ndims", SampleOperations())
+
+ def __call__(self, data):
+ report = deepcopy(self.get_report_format())
+ v_np = concat_val_to_np(data, [self.stats_name, "ndims"])
+ report["ndims"] = self.ops["ndims"].evaluate(v_np)
+ return report
+
+ summarizer = SegSummarizer()
+ summarizer.add_analyzer(UserAnalyzer, UserSummaryAnalyzer)
+
+ """
+ self.transforms += (case_analyzer,)
+ self.summary_analyzers.append(summary_analyzer)
+
+ def summarize(self, data: List[Dict]):
+ """
+ Summarize the input list of data and generates a report ready for json/yaml export.
+
+ Args:
+ data: a list of data dicts.
+
+ Returns:
+ a dict that summarizes the stats across data samples.
+
+ Examples:
+ stats_summary:
+ image_foreground_stats:
+ intensity: {...}
+ image_stats:
+ channels: {...}
+ cropped_shape: {...}
+ ...
+ label_stats:
+ image_intensity: {...}
+ label:
+ - image_intensity: {...}
+ - image_intensity: {...}
+ - image_intensity: {...}
+ - image_intensity: {...}
+ """
+ if not isinstance(data, list):
+ raise ValueError(f"{self.__class__} summarize function needs input to be a list of dict")
+
+ report: Dict[str, Dict] = {}
+ if len(data) == 0:
+ return report
+
+ if not isinstance(data[0], dict):
+ raise ValueError(f"{self.__class__} summarize function needs a list of dict. Now we have {type(data[0])}")
+
+ for analyzer in self.summary_analyzers:
+ if callable(analyzer):
+ report.update({analyzer.stats_name: analyzer(data)})
+
+ return report
diff --git a/monai/auto3dseg/utils.py b/monai/auto3dseg/utils.py
new file mode 100644
index 00000000000..78593f83693
--- /dev/null
+++ b/monai/auto3dseg/utils.py
@@ -0,0 +1,352 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pickle
+import sys
+import warnings
+from copy import deepcopy
+from numbers import Number
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
+
+import numpy as np
+import torch
+
+from monai.auto3dseg import Algo
+from monai.bundle.config_parser import ConfigParser
+from monai.bundle.utils import ID_SEP_KEY
+from monai.data.meta_tensor import MetaTensor
+from monai.transforms import CropForeground, ToCupy
+from monai.utils import min_version, optional_import
+
+__all__ = [
+ "get_foreground_image",
+ "get_foreground_label",
+ "get_label_ccp",
+ "concat_val_to_np",
+ "concat_multikeys_to_dict",
+ "datafold_read",
+ "verify_report_format",
+ "algo_to_pickle",
+ "algo_from_pickle",
+]
+
+measure_np, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
+cp, has_cp = optional_import("cupy")
+cucim, has_cucim = optional_import("cucim")
+
+
+def get_foreground_image(image: MetaTensor):
+ """
+ Get a foreground image by removing all-zero rectangles on the edges of the image
+ Note for the developer: update select_fn if the foreground is defined differently.
+
+ Args:
+ image: ndarray image to segment.
+
+ Returns:
+ ndarray of foreground image by removing all-zero edges.
+
+ Notes:
+ the size of the output is smaller than the input.
+ """
+
+ copper = CropForeground(select_fn=lambda x: x > 0)
+ image_foreground = copper(image)
+ return image_foreground
+
+
+def get_foreground_label(image: MetaTensor, label: MetaTensor) -> MetaTensor:
+ """
+ Get foreground image pixel values and mask out the non-labeled area.
+
+ Args
+ image: ndarray image to segment.
+ label: ndarray the image input and annotated with class IDs.
+
+ Returns:
+ 1D array of foreground image with label > 0
+ """
+
+ label_foreground = MetaTensor(image[label > 0])
+ return label_foreground
+
+
+def get_label_ccp(mask_index: MetaTensor, use_gpu: bool = True) -> Tuple[List[Any], int]:
+ """
+ Find all connected components and their bounding shape. Backend can be cuPy/cuCIM or Numpy
+ depending on the hardware.
+
+ Args:
+ mask_index: a binary mask.
+ use_gpu: a switch to use GPU/CUDA or not. If GPU is unavailable, CPU will be used
+ regardless of this setting.
+
+ """
+
+ shape_list = []
+ if mask_index.device.type == "cuda" and has_cp and has_cucim and use_gpu:
+ mask_cupy = ToCupy()(mask_index.short())
+ labeled = cucim.skimage.measure.label(mask_cupy)
+ vals = cp.unique(labeled[cp.nonzero(labeled)])
+
+ for ncomp in vals:
+ comp_idx = cp.argwhere(labeled == ncomp)
+ comp_idx_min = cp.min(comp_idx, axis=0).tolist()
+ comp_idx_max = cp.max(comp_idx, axis=0).tolist()
+ bbox_shape = [comp_idx_max[i] - comp_idx_min[i] + 1 for i in range(len(comp_idx_max))]
+ shape_list.append(bbox_shape)
+ ncomponents = len(vals)
+
+ del mask_cupy, labeled, vals, comp_idx, ncomp
+ cp.get_default_memory_pool().free_all_blocks()
+
+ elif has_measure:
+ labeled, ncomponents = measure_np.label(mask_index.data.cpu().numpy(), background=-1, return_num=True)
+ for ncomp in range(1, ncomponents + 1):
+ comp_idx = np.argwhere(labeled == ncomp)
+ comp_idx_min = np.min(comp_idx, axis=0).tolist()
+ comp_idx_max = np.max(comp_idx, axis=0).tolist()
+ bbox_shape = [comp_idx_max[i] - comp_idx_min[i] + 1 for i in range(len(comp_idx_max))]
+ shape_list.append(bbox_shape)
+ else:
+ raise RuntimeError("Cannot find one of the following required dependencies: {cuPy+cuCIM} or {scikit-image}")
+
+ return shape_list, ncomponents
+
+
+def concat_val_to_np(
+ data_list: List[Dict],
+ fixed_keys: List[Union[str, int]],
+ ragged: Optional[bool] = False,
+ allow_missing: Optional[bool] = False,
+ **kwargs,
+):
+ """
+ Get the nested value in a list of dictionary that shares the same structure.
+
+ Args:
+ data_list: a list of dictionary {key1: {key2: np.ndarray}}.
+ fixed_keys: a list of keys that records to path to the value in the dict elements.
+ ragged: if True, numbers can be in list of lists or ragged format so concat mode needs change.
+ allow_missing: if True, it will return a None if the value cannot be found.
+
+ Returns:
+ nd.array of concatenated array.
+
+ """
+
+ np_list: List[Optional[np.ndarray]] = []
+ for data in data_list:
+ parser = ConfigParser(data)
+ for i, key in enumerate(fixed_keys):
+ fixed_keys[i] = str(key)
+
+ val: Any
+ val = parser.get(ID_SEP_KEY.join(cast(Iterable[str], fixed_keys)))
+
+ if val is None:
+ if allow_missing:
+ np_list.append(None)
+ else:
+ raise AttributeError(f"{fixed_keys} is not nested in the dictionary")
+ elif isinstance(val, list):
+ np_list.append(np.array(val))
+ elif isinstance(val, (torch.Tensor, MetaTensor)):
+ np_list.append(val.cpu().numpy())
+ elif isinstance(val, np.ndarray):
+ np_list.append(val)
+ elif isinstance(val, Number):
+ np_list.append(np.array(val))
+ else:
+ raise NotImplementedError(f"{val.__class__} concat is not supported.")
+
+ if allow_missing:
+ np_list = [x for x in np_list if x is not None]
+
+ if len(np_list) == 0:
+ return np.array([0])
+ elif ragged:
+ return np.concatenate(np_list, **kwargs) # type: ignore
+ else:
+ return np.concatenate([np_list], **kwargs)
+
+
+def concat_multikeys_to_dict(
+ data_list: List[Dict], fixed_keys: List[Union[str, int]], keys: List[str], zero_insert: bool = True, **kwargs
+):
+ """
+ Get the nested value in a list of dictionary that shares the same structure iteratively on all keys.
+ It returns a dictionary with keys with the found values in nd.ndarray.
+
+ Args:
+ data_list: a list of dictionary {key1: {key2: np.ndarray}}.
+ fixed_keys: a list of keys that records to path to the value in the dict elements.
+ keys: a list of string keys that will be iterated to generate a dict output.
+ zero_insert: insert a zero in the list so that it can find the value in element 0 before getting the keys
+ flatten: if True, numbers are flattened before concat.
+
+ Returns:
+ a dict with keys - nd.array of concatenated array pair.
+ """
+
+ ret_dict = {}
+ for key in keys:
+ addon: List[Union[str, int]] = [0, key] if zero_insert else [key]
+ val = concat_val_to_np(data_list, fixed_keys + addon, **kwargs)
+ ret_dict.update({key: val})
+
+ return ret_dict
+
+
+def datafold_read(datalist: Union[str, Dict], basedir: str, fold: int = 0, key: str = "training") -> Tuple[List, List]:
+ """
+ Read a list of data dictionary `datalist`
+
+ Args:
+ datalist: the name of a JSON file listing the data, or a dictionary.
+ basedir: directory of image files.
+ fold: which fold to use (0..1 if in training set).
+ key: usually 'training' , but can try 'validation' or 'testing' to get the list data without labels (used in challenges).
+
+ Returns:
+ A tuple of two arrays (training, validation).
+ """
+
+ if isinstance(datalist, str):
+ json_data = ConfigParser.load_config_file(datalist)
+ else:
+ json_data = datalist
+
+ dict_data = deepcopy(json_data[key])
+
+ for d in dict_data:
+ for k, _ in d.items():
+ if isinstance(d[k], list):
+ d[k] = [os.path.join(basedir, iv) for iv in d[k]]
+ elif isinstance(d[k], str):
+ d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k]
+
+ tr = []
+ val = []
+ for d in dict_data:
+ if "fold" in d and d["fold"] == fold:
+ val.append(d)
+ else:
+ tr.append(d)
+
+ return tr, val
+
+
+def verify_report_format(report: dict, report_format: dict):
+ """
+ Compares the report and the report_format that has only keys.
+
+ Args:
+ report: dict that has real values.
+ report_format: dict that only has keys and list-nested value.
+ """
+ for k_fmt, v_fmt in report_format.items():
+ if k_fmt not in report:
+ return False
+
+ v = report[k_fmt]
+
+ if isinstance(v_fmt, list) and isinstance(v, list):
+ if len(v_fmt) != 1:
+ raise UserWarning("list length in report_format is not 1")
+ if len(v_fmt) > 0 and len(v) > 0:
+ return verify_report_format(v[0], v_fmt[0])
+ else:
+ return False
+
+ return True
+
+
+def algo_to_pickle(algo: Algo, **algo_meta_data) -> str:
+ """
+ Export the Algo object to pickle file
+
+ Args:
+ algo: Algo-like object
+ algo_meta_data: additional keyword to save into the dictionary. It may include template_path
+ which is used to instantiate the class. It may also include model training info
+ such as acc/best_metrics
+
+ Returns:
+ filename of the pickled Algo object
+ """
+ data = {"algo_bytes": pickle.dumps(algo)}
+ pkl_filename = os.path.join(algo.get_output_path(), "algo_object.pkl")
+ for k, v in algo_meta_data.items():
+ data.update({k: v})
+ data_bytes = pickle.dumps(data)
+ with open(pkl_filename, "wb") as f_pi:
+ f_pi.write(data_bytes)
+ return pkl_filename
+
+
+def algo_from_pickle(pkl_filename: str, **kwargs) -> Any:
+ """
+ Import the Algo object from a pickle file
+
+ Args:
+ pkl_filename: name of the pickle file
+ algo_templates_dir: the algorithm script folder which is needed to instantiate the object.
+ If it is None, the function will use the internal ``'algo_templates_dir`` in the object
+ dict.
+
+ Returns:
+ algo: Algo-like object
+
+ Raises:
+ ValueError if the pkl_filename does not contain a dict, or the dict does not contain
+ ``template_path`` or ``algo_bytes``
+ """
+ with open(pkl_filename, "rb") as f_pi:
+ data_bytes = f_pi.read()
+ data = pickle.loads(data_bytes)
+
+ if not isinstance(data, dict):
+ raise ValueError(f"the data object is {data.__class__}. Dict is expected.")
+
+ if "algo_bytes" not in data:
+ raise ValueError(f"key [algo_bytes] not found in {data}. Unable to instantiate.")
+
+ algo_bytes = data.pop("algo_bytes")
+ algo_meta_data = {}
+
+ if "template_path" in kwargs: # add template_path to sys.path
+ template_path = kwargs["template_path"]
+ if template_path is None: # then load template_path from pickled data
+ if "template_path" not in data:
+ raise ValueError(f"key [template_path] not found in {data}")
+ template_path = data.pop("template_path")
+
+ if not os.path.isdir(template_path):
+ raise ValueError(f"Algorithm templates {template_path} is not a directory")
+ # Example of template path: "algorithm_templates/dints".
+ sys.path.insert(0, os.path.abspath(os.path.join(template_path, "..")))
+ algo_meta_data.update({"template_path": template_path})
+
+ algo = pickle.loads(algo_bytes)
+ pkl_dir = os.path.dirname(pkl_filename)
+ if pkl_dir != algo.get_output_path():
+ warnings.warn(
+ f"{algo.get_output_path()} does not contain {pkl_filename}."
+ f"Now override the Algo output_path with: {pkl_dir}"
+ )
+ algo.output_path = pkl_dir
+
+ for k, v in data.items():
+ algo_meta_data.update({k: v})
+
+ return algo, algo_meta_data
diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py
index 2a658a3c51f..222799e0183 100644
--- a/monai/bundle/__init__.py
+++ b/monai/bundle/__init__.py
@@ -12,5 +12,16 @@
from .config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem, Instantiable
from .config_parser import ConfigParser
from .reference_resolver import ReferenceResolver
-from .scripts import ckpt_export, download, init_bundle, load, run, verify_metadata, verify_net_in_out
+from .scripts import (
+ ckpt_export,
+ download,
+ get_all_bundles_list,
+ get_bundle_info,
+ get_bundle_versions,
+ init_bundle,
+ load,
+ run,
+ verify_metadata,
+ verify_net_in_out,
+)
from .utils import EXPR_KEY, ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, load_bundle_config
diff --git a/monai/bundle/__main__.py b/monai/bundle/__main__.py
index ace3701d195..a9671fe3850 100644
--- a/monai/bundle/__main__.py
+++ b/monai/bundle/__main__.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from monai.bundle.scripts import ckpt_export, download, init_bundle, run, verify_metadata, verify_net_in_out
if __name__ == "__main__":
diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py
index 7e305f80306..0c46665bf50 100644
--- a/monai/bundle/config_item.py
+++ b/monai/bundle/config_item.py
@@ -11,7 +11,6 @@
import ast
import inspect
-import os
import sys
import warnings
from abc import ABC, abstractmethod
@@ -19,7 +18,7 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from monai.bundle.utils import EXPR_KEY
-from monai.utils import ensure_tuple, first, instantiate, optional_import
+from monai.utils import ensure_tuple, first, instantiate, optional_import, run_debug, run_eval
__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent", "Instantiable"]
@@ -107,7 +106,7 @@ def get_component_module_name(self, name: str) -> Optional[Union[List[str], str]
# init component and module mapping table
self._components_table = self._find_classes_or_functions(self._find_module_names())
- mods: Optional[Union[List[str], str]] = self._components_table.get(name, None)
+ mods: Optional[Union[List[str], str]] = self._components_table.get(name)
if isinstance(mods, list) and len(mods) == 1:
mods = mods[0]
return mods
@@ -176,6 +175,7 @@ class ConfigComponent(ConfigItem, Instantiable):
component doesn't explicitly depend on the other `ConfigItems` via its arguments,
but requires the dependencies to be instantiated/evaluated beforehand.
- ``"_disabled_"`` (optional): a flag to indicate whether to skip the instantiation.
+ - ``"_desc_"`` (optional): free text descriptions of the component for code readability.
Other fields in the config content are input arguments to the python module.
@@ -203,7 +203,7 @@ class ConfigComponent(ConfigItem, Instantiable):
"""
- non_arg_keys = {"_target_", "_disabled_", "_requires_"}
+ non_arg_keys = {"_target_", "_disabled_", "_requires_", "_desc_"}
def __init__(
self,
@@ -257,7 +257,7 @@ def resolve_args(self):
"""
return {k: v for k, v in self.get_config().items() if k not in self.non_arg_keys}
- def is_disabled(self) -> bool: # type: ignore
+ def is_disabled(self) -> bool:
"""
Utility function used in `instantiate()` to check whether to skip the instantiation.
@@ -265,7 +265,7 @@ def is_disabled(self) -> bool: # type: ignore
_is_disabled = self.get_config().get("_disabled_", False)
return _is_disabled.lower().strip() == "true" if isinstance(_is_disabled, str) else bool(_is_disabled)
- def instantiate(self, **kwargs) -> object: # type: ignore
+ def instantiate(self, **kwargs) -> object:
"""
Instantiate component based on ``self.config`` content.
The target component must be a `class` or a `function`, otherwise, return `None`.
@@ -315,7 +315,7 @@ class ConfigExpression(ConfigItem):
"""
prefix = EXPR_KEY
- run_eval = os.environ.get("MONAI_EVAL_EXPR", "1") != "0"
+ run_eval = run_eval
def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None:
super().__init__(config=config, id=id)
@@ -364,7 +364,15 @@ def evaluate(self, globals: Optional[Dict] = None, locals: Optional[Dict] = None
if k in globals_:
warnings.warn(f"the new global variable `{k}` conflicts with `self.globals`, override it.")
globals_[k] = v
- return eval(value[len(self.prefix) :], globals_, locals)
+ if not run_debug:
+ return eval(value[len(self.prefix) :], globals_, locals)
+ warnings.warn(
+ f"\n\npdb: value={value}\n"
+ f"See also Debugger commands documentation: https://docs.python.org/3/library/pdb.html\n"
+ )
+ import pdb
+
+ return pdb.run(value[len(self.prefix) :], globals_, locals)
@classmethod
def is_expression(cls, config: Union[Dict, List, str]) -> bool:
diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py
index b4ffc853bbd..d57238cfaa9 100644
--- a/monai/bundle/config_parser.py
+++ b/monai/bundle/config_parser.py
@@ -76,6 +76,7 @@ class ConfigParser:
The current supported globals and alias names are
``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``.
These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`.
+ Set it to ``False`` to disable `self.globals` module importing.
See also:
@@ -95,14 +96,14 @@ def __init__(
self,
config: Any = None,
excludes: Optional[Union[Sequence[str], str]] = None,
- globals: Optional[Dict[str, Any]] = None,
+ globals: Union[Dict[str, Any], None, bool] = None,
):
self.config = None
self.globals: Dict[str, Any] = {}
_globals = _default_globals.copy()
- if isinstance(_globals, dict) and globals is not None:
- _globals.update(globals)
- if _globals is not None:
+ if isinstance(_globals, dict) and globals not in (None, False):
+ _globals.update(globals) # type: ignore
+ if _globals is not None and globals is not False:
for k, v in _globals.items():
self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v
@@ -132,8 +133,12 @@ def __getitem__(self, id: Union[str, int]):
for k in str(id).split(ID_SEP_KEY):
if not isinstance(config, (dict, list)):
raise ValueError(f"config must be dict or list for key `{k}`, but got {type(config)}: {config}.")
- indexing = k if isinstance(config, dict) else int(k)
- config = config[indexing]
+ try:
+ config = (
+ look_up_option(k, config, print_all_options=False) if isinstance(config, dict) else config[int(k)]
+ )
+ except ValueError as e:
+ raise KeyError(f"query key: {k}") from e
return config
def __setitem__(self, id: Union[str, int], config: Any):
@@ -157,6 +162,7 @@ def __setitem__(self, id: Union[str, int], config: Any):
# get the last parent level config item and replace it
last_id = ID_SEP_KEY.join(keys[:-1])
conf_ = self[last_id]
+
indexing = keys[-1] if isinstance(conf_, dict) else int(keys[-1])
conf_[indexing] = config
self.ref_resolver.reset()
@@ -173,18 +179,29 @@ def get(self, id: str = "", default: Optional[Any] = None):
"""
try:
return self[id]
- except KeyError:
+ except (KeyError, IndexError, ValueError): # Index error for integer indexing
return default
- def set(self, config: Any, id: str = ""):
+ def set(self, config: Any, id: str = "", recursive: bool = True):
"""
Set config by ``id``.
Args:
config: config to set at location ``id``.
id: id to specify the expected position. See also :py:meth:`__setitem__`.
+ recursive: if the nested id doesn't exist, whether to recursively create the nested items in the config.
+ default to `True`. for the nested id, only support `dict` for the missing section.
"""
+ keys = str(id).split(ID_SEP_KEY)
+ conf_ = self.get()
+ if recursive:
+ if conf_ is None:
+ self.config = conf_ = {} # type: ignore
+ for k in keys[:-1]:
+ if isinstance(conf_, dict) and k not in conf_:
+ conf_[k] = {}
+ conf_ = conf_[k if isinstance(conf_, dict) else int(k)]
self[id] = config
def update(self, pairs: Dict[str, Any]):
@@ -209,7 +226,7 @@ def __contains__(self, id: Union[str, int]) -> bool:
try:
_ = self[id]
return True
- except KeyError:
+ except (KeyError, IndexError, ValueError): # Index error for integer indexing
return False
def parse(self, reset: bool = True):
@@ -357,6 +374,8 @@ def load_config_file(cls, filepath: PathLike, **kwargs):
kwargs: other arguments for ``json.load`` or ```yaml.safe_load``, depends on the file format.
"""
+ if not filepath:
+ return {}
_filepath: str = str(Path(filepath))
if not re.compile(cls.path_match, re.IGNORECASE).findall(_filepath):
raise ValueError(f'unknown file input: "{filepath}"')
@@ -403,7 +422,8 @@ def export_config_file(cls, config: Dict, filepath: PathLike, fmt="json", **kwar
writer = look_up_option(fmt.lower(), {"json", "yaml"})
with open(_filepath, "w") as f:
if writer == "json":
- return json.dump(config, f, **kwargs)
+ json.dump(config, f, **kwargs)
+ return
if writer == "yaml":
return yaml.safe_dump(config, f, **kwargs)
raise ValueError(f"only support JSON or YAML config file so far, got {writer}.")
diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py
index 5ad46518fd6..ed40d7099ae 100644
--- a/monai/bundle/reference_resolver.py
+++ b/monai/bundle/reference_resolver.py
@@ -9,14 +9,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import re
import warnings
from typing import Any, Dict, Optional, Sequence, Set
from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY
-from monai.utils import look_up_option
+from monai.utils import allow_missing_reference, look_up_option
__all__ = ["ReferenceResolver"]
@@ -53,7 +52,7 @@ class ReferenceResolver:
# match a reference string, e.g. "@id#key", "@id#key#0", "@_target_#key"
id_matcher = re.compile(rf"{ref}(?:\w*)(?:{sep}\w*)*")
# if `allow_missing_reference` and can't find a reference ID, will just raise a warning and don't update the config
- allow_missing_reference = os.environ.get("MONAI_ALLOW_MISSING_REFERENCE", "0") != "0"
+ allow_missing_reference = allow_missing_reference
def __init__(self, items: Optional[Sequence[ConfigItem]] = None):
# save the items in a dictionary with the `ConfigItem.id` as key
diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
index c01498bb093..27227df3ea7 100644
--- a/monai/bundle/scripts.py
+++ b/monai/bundle/scripts.py
@@ -153,6 +153,7 @@ def _process_bundle_dir(bundle_dir: Optional[PathLike] = None):
def download(
name: Optional[str] = None,
+ version: Optional[str] = None,
bundle_dir: Optional[PathLike] = None,
source: str = "github",
repo: str = "Project-MONAI/model-zoo/hosting_storage_v1",
@@ -170,11 +171,14 @@ def download(
.. code-block:: bash
+ # Execute this module as a CLI entry, and download bundle from the model-zoo repo:
+ python -m monai.bundle download --name --version "0.1.0" --bundle_dir "./"
+
# Execute this module as a CLI entry, and download bundle:
- python -m monai.bundle download --name "bundle_name" --source "github" --repo "repo_owner/repo_name/release_tag"
+ python -m monai.bundle download --name --source "github" --repo "repo_owner/repo_name/release_tag"
# Execute this module as a CLI entry, and download bundle via URL:
- python -m monai.bundle download --name "bundle_name" --url
+ python -m monai.bundle download --name --url
# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
# Other args still can override the default args at runtime.
@@ -185,6 +189,9 @@ def download(
Args:
name: bundle name. If `None` and `url` is `None`, it must be provided in `args_file`.
+ for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
+ https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
+ version: version name of the target bundle to download, like: "0.1.0".
bundle_dir: target directory to store the downloaded data.
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `url` is `None`.
@@ -200,19 +207,28 @@ def download(
"""
_args = _update_args(
- args=args_file, name=name, bundle_dir=bundle_dir, source=source, repo=repo, url=url, progress=progress
+ args=args_file,
+ name=name,
+ version=version,
+ bundle_dir=bundle_dir,
+ source=source,
+ repo=repo,
+ url=url,
+ progress=progress,
)
_log_input_summary(tag="download", args=_args)
- source_, repo_, progress_, name_, bundle_dir_, url_ = _pop_args(
- _args, "source", "repo", "progress", name=None, bundle_dir=None, url=None
+ source_, repo_, progress_, name_, version_, bundle_dir_, url_ = _pop_args(
+ _args, "source", "repo", "progress", name=None, version=None, bundle_dir=None, url=None
)
bundle_dir_ = _process_bundle_dir(bundle_dir_)
+ if name_ is not None and version_ is not None:
+ name_ = "_v".join([name_, version_])
if url_ is not None:
- if name is not None:
- filepath = bundle_dir_ / f"{name}.zip"
+ if name_ is not None:
+ filepath = bundle_dir_ / f"{name_}.zip"
else:
filepath = bundle_dir_ / f"{_basename(url_)}"
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
@@ -229,6 +245,7 @@ def download(
def load(
name: str,
+ version: Optional[str] = None,
model_file: Optional[str] = None,
load_ts_module: bool = False,
bundle_dir: Optional[PathLike] = None,
@@ -245,7 +262,9 @@ def load(
Load model weights or TorchScript module of a bundle.
Args:
- name: bundle name.
+ name: bundle name, for example: "spleen_ct_segmentation", "prostate_mri_anatomy" in the model-zoo:
+ https://github.com/Project-MONAI/model-zoo/releases/tag/hosting_storage_v1.
+ version: version name of the target bundle to download, like: "0.1.0".
model_file: the relative path of the model weights or TorchScript module within bundle.
If `None`, "models/model.pt" or "models/model.ts" will be used.
load_ts_module: a flag to specify if loading the TorchScript module.
@@ -280,7 +299,7 @@ def load(
model_file = os.path.join("models", "model.ts" if load_ts_module is True else "model.pt")
full_path = os.path.join(bundle_dir_, name, model_file)
if not os.path.exists(full_path):
- download(name=name, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress)
+ download(name=name, version=version, bundle_dir=bundle_dir_, source=source, repo=repo, progress=progress)
if device is None:
device = "cuda:0" if is_available() else "cpu"
@@ -303,6 +322,153 @@ def load(
return model
+def _get_all_bundles_info(
+ repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: Optional[str] = None
+):
+ if has_requests:
+ request_url = f"https://api.github.com/repos/{repo}/releases"
+ if auth_token is not None:
+ headers = {"Authorization": f"Bearer {auth_token}"}
+ resp = requests_get(request_url, headers=headers)
+ else:
+ resp = requests_get(request_url)
+ resp.raise_for_status()
+ else:
+ raise ValueError("requests package is required, please install it.")
+ releases_list = json.loads(resp.text)
+ bundle_name_pattern = re.compile(r"_v\d*.")
+ bundles_info: Dict = {}
+
+ for release in releases_list:
+ if release["tag_name"] == tag:
+ for asset in release["assets"]:
+ asset_name = bundle_name_pattern.split(asset["name"])[0]
+ if asset_name not in bundles_info:
+ bundles_info[asset_name] = {}
+ asset_version = asset["name"].split(f"{asset_name}_v")[-1].replace(".zip", "")
+ bundles_info[asset_name][asset_version] = {
+ "id": asset["id"],
+ "name": asset["name"],
+ "size": asset["size"],
+ "download_count": asset["download_count"],
+ "browser_download_url": asset["browser_download_url"],
+ "created_at": asset["created_at"],
+ "updated_at": asset["updated_at"],
+ }
+ return bundles_info
+ return bundles_info
+
+
+def get_all_bundles_list(
+ repo: str = "Project-MONAI/model-zoo", tag: str = "hosting_storage_v1", auth_token: Optional[str] = None
+):
+ """
+ Get all bundles names (and the latest versions) that are stored in the release of specified repository
+ with the provided tag. The default values of arguments correspond to the release of MONAI model zoo.
+ In order to increase the rate limits of calling GIthub APIs, you can input your personal access token.
+ Please check the following link for more details about rate limiting:
+ https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
+
+ The following link shows how to create your personal access token:
+ https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token
+
+ Args:
+ repo: it should be in the form of "repo_owner/repo_name/".
+ tag: the tag name of the release.
+ auth_token: github personal access token.
+
+ Returns:
+ a list of tuple in the form of (bundle name, latest version).
+
+ """
+
+ bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)
+ bundles_list = []
+ for bundle_name in bundles_info.keys():
+ latest_version = sorted(bundles_info[bundle_name].keys())[-1]
+ bundles_list.append((bundle_name, latest_version))
+
+ return bundles_list
+
+
+def get_bundle_versions(
+ bundle_name: str,
+ repo: str = "Project-MONAI/model-zoo",
+ tag: str = "hosting_storage_v1",
+ auth_token: Optional[str] = None,
+):
+ """
+ Get the latest version, as well as all existing versions of a bundle that is stored in the release of specified
+ repository with the provided tag.
+ In order to increase the rate limits of calling GIthub APIs, you can input your personal access token.
+ Please check the following link for more details about rate limiting:
+ https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
+
+ The following link shows how to create your personal access token:
+ https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token
+
+ Args:
+ bundle_name: bundle name.
+ repo: it should be in the form of "repo_owner/repo_name/".
+ tag: the tag name of the release.
+ auth_token: github personal access token.
+
+ Returns:
+ a dictionary that contains the latest version and all versions of a bundle.
+
+ """
+
+ bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)
+ if bundle_name not in bundles_info:
+ raise ValueError(f"bundle: {bundle_name} is not existing.")
+ bundle_info = bundles_info[bundle_name]
+ all_versions = sorted(bundle_info.keys())
+
+ return {"latest_version": all_versions[-1], "all_versions": all_versions}
+
+
+def get_bundle_info(
+ bundle_name: str,
+ version: Optional[str] = None,
+ repo: str = "Project-MONAI/model-zoo",
+ tag: str = "hosting_storage_v1",
+ auth_token: Optional[str] = None,
+):
+ """
+ Get all information
+ (include "id", "name", "size", "download_count", "browser_download_url", "created_at", "updated_at") of a bundle
+ with the specified bundle name and version.
+ In order to increase the rate limits of calling GIthub APIs, you can input your personal access token.
+ Please check the following link for more details about rate limiting:
+ https://docs.github.com/en/rest/overview/resources-in-the-rest-api#rate-limiting
+
+ The following link shows how to create your personal access token:
+ https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token
+
+ Args:
+ bundle_name: bundle name.
+ version: version name of the target bundle, if None, the latest version will be used.
+ repo: it should be in the form of "repo_owner/repo_name/".
+ tag: the tag name of the release.
+ auth_token: github personal access token.
+
+ Returns:
+ a dictionary that contains the bundle's information.
+
+ """
+
+ bundles_info = _get_all_bundles_info(repo=repo, tag=tag, auth_token=auth_token)
+ if bundle_name not in bundles_info:
+ raise ValueError(f"bundle: {bundle_name} is not existing.")
+ bundle_info = bundles_info[bundle_name]
+ if version is None:
+ version = sorted(bundle_info.keys())[-1]
+ if version not in bundle_info:
+ raise ValueError(f"version: {version} of bundle: {bundle_name} is not existing.")
+
+ return bundle_info[version]
+
+
def run(
runner_id: Optional[Union[str, Sequence[str]]] = None,
meta_file: Optional[Union[str, Sequence[str]]] = None,
@@ -357,12 +523,14 @@ def run(
**override,
)
if "config_file" not in _args:
- raise ValueError(f"`config_file` is required for 'monai.bundle run'.\n{run.__doc__}")
+ warnings.warn("`config_file` not provided for 'monai.bundle run'.")
_log_input_summary(tag="run", args=_args)
config_file_, meta_file_, runner_id_, logging_file_ = _pop_args(
- _args, "config_file", meta_file=None, runner_id="", logging_file=None
+ _args, config_file=None, meta_file=None, runner_id="", logging_file=None
)
if logging_file_ is not None:
+ if not os.path.exists(logging_file_):
+ raise FileNotFoundError(f"can't find the logging config file: {logging_file_}.")
logger.info(f"set logging properties based on config: {logging_file_}.")
fileConfig(logging_file_, disable_existing_loggers=False)
diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py
index c3a8343163e..f382fd820e6 100644
--- a/monai/bundle/utils.py
+++ b/monai/bundle/utils.py
@@ -21,13 +21,11 @@
__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY"]
-
ID_REF_KEY = "@" # start of a reference to a ConfigItem
ID_SEP_KEY = "#" # separator for the ID of a ConfigItem
EXPR_KEY = "$" # start of a ConfigExpression
MACRO_KEY = "%" # start of a macro of a config
-
_conf_values = get_config_values()
DEFAULT_METADATA = {
diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py
index bb6f87e97ac..5c360b5536c 100644
--- a/monai/config/type_definitions.py
+++ b/monai/config/type_definitions.py
@@ -41,7 +41,6 @@
"SequenceStr",
]
-
#: KeysCollection
#
# The KeyCollection type is used to for defining variables
diff --git a/monai/data/__init__.py b/monai/data/__init__.py
index 3dbc6649f21..65ee8c377ff 100644
--- a/monai/data/__init__.py
+++ b/monai/data/__init__.py
@@ -106,8 +106,11 @@
worker_init_fn,
zoom_affine,
)
+
+# FIXME: workaround for https://github.com/Project-MONAI/MONAI/issues/5291
+# from .video_dataset import CameraDataset, VideoDataset, VideoFileDataset
from .wsi_datasets import MaskedPatchWSIDataset, PatchWSIDataset, SlidingPatchWSIDataset
-from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, WSIReader
+from .wsi_reader import BaseWSIReader, CuCIMWSIReader, OpenSlideWSIReader, TiffFileWSIReader, WSIReader
with contextlib.suppress(BaseException):
from multiprocessing.reduction import ForkingPickler
diff --git a/monai/data/box_utils.py b/monai/data/box_utils.py
index afe8e111675..a1e321b623d 100644
--- a/monai/data/box_utils.py
+++ b/monai/data/box_utils.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
This utility module mainly supports rectangular bounding boxes with a few
different parameterizations and methods for converting between them. It
@@ -36,7 +35,6 @@
# We support 2-D or 3-D bounding boxes
SUPPORTED_SPATIAL_DIMS = [2, 3]
-
# TO_REMOVE = 0.0 if the bottom-right corner pixel/voxel is not included in the boxes,
# i.e., when xmin=1., xmax=2., we have w = 1.
# TO_REMOVE = 1.0 if the bottom-right corner pixel/voxel is included in the boxes,
@@ -657,7 +655,7 @@ def boxes_center_distance(
center2 = box_centers(boxes2_t.to(COMPUTE_DTYPE)) # (M, spatial_dims)
if euclidean:
- dists = (center1[:, None] - center2[None]).pow(2).sum(-1).sqrt()
+ dists = (center1[:, None] - center2[None]).pow(2).sum(-1).sqrt() # type: ignore
else:
# before sum: (N, M, spatial_dims)
dists = (center1[:, None] - center2[None]).sum(-1)
diff --git a/monai/data/dataloader.py b/monai/data/dataloader.py
index d1f5bd4fe11..f43211f1849 100644
--- a/monai/data/dataloader.py
+++ b/monai/data/dataloader.py
@@ -66,6 +66,8 @@ def __len__(self):
num_workers: how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``0``)
+ collate_fn: default to :py:func:`monai.data.utils.list_data_collate`.
+ worker_init_fn: default to :py:func:`monai.data.utils.worker_init_fn`.
kwargs: other parameters for PyTorch DataLoader.
"""
@@ -74,11 +76,14 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
# when num_workers > 0, random states are determined by worker_init_fn
# this is to make the behavior consistent when num_workers == 0
# torch.int64 doesn't work well on some versions of windows
- _seed = torch.empty((), dtype=torch.int32).random_(generator=None).item()
+ _g = torch.random.default_generator if kwargs.get("generator") is None else kwargs["generator"]
+ init_seed = _g.initial_seed()
+ _seed = torch.empty((), dtype=torch.int64).random_(generator=_g).item()
set_rnd(dataset, int(_seed))
+ _g.manual_seed(init_seed)
if "collate_fn" not in kwargs:
- kwargs.update({"collate_fn": list_data_collate})
+ kwargs["collate_fn"] = list_data_collate
if "worker_init_fn" not in kwargs:
- kwargs.update({"worker_init_fn": worker_init_fn})
+ kwargs["worker_init_fn"] = worker_init_fn
super().__init__(dataset=dataset, num_workers=num_workers, **kwargs)
diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index dd66d7ee147..2c263b3e32e 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import collections.abc
import math
import pickle
@@ -790,7 +789,8 @@ def __init__(
if self.num_workers is not None:
self.num_workers = max(int(self.num_workers), 1)
self.cache_num = 0
- self._cache: Union[List, Dict] = []
+ self._cache: List = []
+ self._hash_keys: List = []
self.set_data(data)
def set_data(self, data: Sequence):
@@ -802,37 +802,41 @@ def set_data(self, data: Sequence):
generated cache content.
"""
+ self.data = data
- def _compute_cache():
- self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data))
- return self._fill_cache()
+ def _compute_cache_num(data_len: int):
+ self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len)
if self.hash_as_key:
- # only compute cache for the unique items of dataset
- mapping = {self.hash_func(v): v for v in data}
- self.data = list(mapping.values())
- cache_ = _compute_cache()
- self._cache = dict(zip(list(mapping)[: self.cache_num], cache_))
- self.data = data
+ # only compute cache for the unique items of dataset, and record the last index for duplicated items
+ mapping = {self.hash_func(v): i for i, v in enumerate(data)}
+ _compute_cache_num(len(mapping))
+ self._hash_keys = list(mapping)[: self.cache_num]
+ indices = list(mapping.values())[: self.cache_num]
else:
- self.data = data
- self._cache = _compute_cache()
+ _compute_cache_num(len(self.data))
+ indices = list(range(self.cache_num))
+ self._cache = self._fill_cache(indices)
+
+ def _fill_cache(self, indices=None) -> List:
+ """
+ Compute and fill the cache content from data source.
+
+ Args:
+ indices: target indices in the `self.data` source to compute cache.
+ if None, use the first `cache_num` items.
- def _fill_cache(self) -> List:
+ """
if self.cache_num <= 0:
return []
+ if indices is None:
+ indices = list(range(self.cache_num))
if self.progress and not has_tqdm:
warnings.warn("tqdm is not installed, will not show the caching progress bar.")
with ThreadPool(self.num_workers) as p:
if self.progress and has_tqdm:
- return list(
- tqdm(
- p.imap(self._load_cache_item, range(self.cache_num)),
- total=self.cache_num,
- desc="Loading dataset",
- )
- )
- return list(p.imap(self._load_cache_item, range(self.cache_num)))
+ return list(tqdm(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset"))
+ return list(p.imap(self._load_cache_item, indices))
def _load_cache_item(self, idx: int):
"""
@@ -851,21 +855,24 @@ def _load_cache_item(self, idx: int):
return item
def _transform(self, index: int):
- index_: Any = index
+ cache_index = None
if self.hash_as_key:
key = self.hash_func(self.data[index])
- if key in self._cache:
- # if existing in cache, get the index
- index_ = key # if using hash as cache keys, set the key
+ if key in self._hash_keys:
+ # if existing in cache, try to get the index in cache
+ cache_index = self._hash_keys.index(key)
+ elif index % len(self) < self.cache_num: # support negative index
+ cache_index = index
- if isinstance(index_, int) and index_ % len(self) >= self.cache_num: # support negative index
+ if cache_index is None:
# no cache for this index, execute all the transforms directly
- return super()._transform(index_)
+ return super()._transform(index)
+
+ if self._cache is None:
+ raise RuntimeError("cache buffer is not initialized, please call `set_data()` first.")
+ data = self._cache[cache_index]
# load data from cache and execute from the first random transform
start_run = False
- if self._cache is None:
- self._cache = self._fill_cache()
- data = self._cache[index_]
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
@@ -1428,7 +1435,7 @@ class CSVDataset(Dataset):
@deprecated_arg(name="filename", new_name="src", since="0.8", msg_suffix="please use `src` instead.")
def __init__(
self,
- src: Optional[Union[str, Sequence[str]]] = None, # also can be `DataFrame` or sequense of `DataFrame`
+ src: Optional[Union[str, Sequence[str]]] = None, # also can be `DataFrame` or a sequence of `DataFrame`
row_indices: Optional[Sequence[Union[int, str]]] = None,
col_names: Optional[Sequence[str]] = None,
col_types: Optional[Dict[str, Optional[Dict[str, Any]]]] = None,
diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py
index ec644afaf62..34e1368fe2d 100644
--- a/monai/data/image_reader.py
+++ b/monai/data/image_reader.py
@@ -28,7 +28,16 @@
orientation_ras_lps,
)
from monai.transforms.utility.array import EnsureChannelFirst
-from monai.utils import MetaKeys, SpaceKeys, ensure_tuple, ensure_tuple_rep, optional_import, require_pkg
+from monai.utils import (
+ MetaKeys,
+ SpaceKeys,
+ TraceKeys,
+ deprecated,
+ ensure_tuple,
+ ensure_tuple_rep,
+ optional_import,
+ require_pkg,
+)
if TYPE_CHECKING:
import itk
@@ -131,7 +140,7 @@ def _copy_compatible_dict(from_dict: Dict, to_dict: Dict):
datum = from_dict[key]
if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None:
continue
- to_dict[key] = datum
+ to_dict[key] = str(TraceKeys.NONE) if datum is None else datum # NoneType to string for default_collate
else:
affine_key, shape_key = MetaKeys.AFFINE, MetaKeys.SPATIAL_SHAPE
if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]):
@@ -309,7 +318,12 @@ def _get_meta_dict(self, img) -> Dict:
"""
img_meta_dict = img.GetMetaDataDictionary()
- meta_dict = {key: img_meta_dict[key] for key in img_meta_dict.GetKeys() if not key.startswith("ITK_")}
+ meta_dict = {}
+ for key in img_meta_dict.GetKeys():
+ if key.startswith("ITK_"):
+ continue
+ val = img_meta_dict[key]
+ meta_dict[key] = np.asarray(val) if type(val).__name__.startswith("itk") else val
meta_dict["spacing"] = np.asarray(img.GetSpacing())
return meta_dict
@@ -1218,6 +1232,7 @@ def _get_spatial_shape(self, img):
return np.asarray((img.width, img.height))
+@deprecated(since="0.8", msg_suffix="use `monai.wsi_reader.WSIReader` instead.")
class WSIReader(ImageReader):
"""
Read whole slide images and extract patches.
diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py
index cf7f3e02913..95162a58ea1 100644
--- a/monai/data/image_writer.py
+++ b/monai/data/image_writer.py
@@ -46,7 +46,6 @@
nib, _ = optional_import("nibabel")
PILImage, _ = optional_import("PIL.Image")
-
__all__ = [
"ImageWriter",
"ITKWriter",
@@ -602,6 +601,12 @@ def write(self, filename: PathLike, verbose: bool = False, **obj_kwargs):
self.data_obj = self.create_backend_obj(
self.data_obj, affine=self.affine, dtype=self.output_dtype, **obj_kwargs # type: ignore
)
+ if self.affine is None:
+ self.affine = np.eye(4)
+ # ITK v5.2.1/Modules/IO/NIFTI/src/itkNiftiImageIO.cxx#L2175-L2176
+ _affine = to_affine_nd(r=3, affine=convert_data_type(self.affine, np.ndarray)[0])
+ self.data_obj.set_sform(_affine, code=1)
+ self.data_obj.set_qform(_affine, code=1)
nib.save(self.data_obj, filename)
@classmethod
diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py
index f1906a80fe8..009bd31031e 100644
--- a/monai/data/iterable_dataset.py
+++ b/monai/data/iterable_dataset.py
@@ -72,6 +72,25 @@ class ShuffleBuffer(Randomizable, IterableDataset):
every iter() call, refer to the PyTorch idea:
https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.
+ Note:
+ Both ``monai.data.DataLoader`` and ``torch.utils.data.DataLoader`` do not seed this class (as a subclass of
+ ``IterableDataset``) at run time. ``persistent_workers=True`` flag (and pytorch>1.8) is therefore required
+ for multiple epochs of loading when ``num_workers>0``. For example::
+
+ import monai
+
+ def run():
+ dss = monai.data.ShuffleBuffer([1, 2, 3, 4], buffer_size=30, seed=42)
+
+ dataloader = monai.data.DataLoader(
+ dss, batch_size=1, num_workers=2, persistent_workers=True)
+ for epoch in range(3):
+ for item in dataloader:
+ print(f"epoch: {epoch} item: {item}.")
+
+ if __name__ == '__main__':
+ run()
+
"""
def __init__(self, data, transform=None, buffer_size: int = 512, seed: int = 0) -> None:
@@ -80,36 +99,31 @@ def __init__(self, data, transform=None, buffer_size: int = 512, seed: int = 0)
self.seed = seed
self._idx = 0
+ def randomized_pop(self, buffer):
+ """Return the item at a randomized location `self._idx` in `buffer`."""
+ self.randomize(len(buffer))
+ ret, buffer[self._idx] = buffer[self._idx], buffer[-1]
+ buffer.pop()
+ return ret
+
+ def generate_item(self):
+ """Fill a `buffer` list up to `self.size`, then generate randomly popped items."""
+ buffer = []
+ for item in iter(self.data):
+ if len(buffer) >= self.size:
+ yield self.randomized_pop(buffer)
+ buffer.append(item)
+ while buffer:
+ yield self.randomized_pop(buffer)
+
def __iter__(self):
"""
- Fetch data from the source, if buffer is not full, fill into buffer, otherwise,
- randomly pop items from the buffer.
- After loading all the data from source, randomly pop items from the buffer.
-
+ Randomly pop buffered items from `self.data`.
+ Multiple dataloader workers sharing this dataset will generate identical item sequences.
"""
self.seed += 1
super().set_random_state(seed=self.seed) # make all workers in sync
- buffer = []
- source = self.data
-
- def _pop_item():
- self.randomize(len(buffer))
- # switch random index data and the last index data
- ret, buffer[self._idx] = buffer[self._idx], buffer[-1]
- buffer.pop()
- return ret
-
- def _get_item():
- for item in source:
- if len(buffer) >= self.size:
- yield _pop_item()
- buffer.append(item)
-
- while buffer:
- yield _pop_item()
-
- self.data = _get_item()
- return super().__iter__()
+ yield from IterableDataset(self.generate_item(), transform=self.transform)
def randomize(self, size: int) -> None:
self._idx = self.R.randint(size)
diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py
index 3a1bee508c0..6aab05dc94a 100644
--- a/monai/data/meta_obj.py
+++ b/monai/data/meta_obj.py
@@ -82,6 +82,7 @@ class MetaObj:
def __init__(self):
self._meta: dict = MetaObj.get_default_meta()
self._applied_operations: list = MetaObj.get_default_applied_operations()
+ self._pending_operations: list = MetaObj.get_default_applied_operations() # the same default as applied_ops
self._is_batch: bool = False
@staticmethod
@@ -174,7 +175,8 @@ def meta(self, d) -> None:
"""Set the meta."""
if d == TraceKeys.NONE:
self._meta = MetaObj.get_default_meta()
- self._meta = d
+ else:
+ self._meta = d
@property
def applied_operations(self) -> list[dict]:
@@ -198,6 +200,19 @@ def push_applied_operation(self, t: Any) -> None:
def pop_applied_operation(self) -> Any:
return self._applied_operations.pop()
+ @property
+ def pending_operations(self) -> list[dict]:
+ """Get the pending operations. Defaults to ``[]``."""
+ if hasattr(self, "_pending_operations"):
+ return self._pending_operations
+ return MetaObj.get_default_applied_operations() # the same default as applied_ops
+
+ def push_pending_operation(self, t: Any) -> None:
+ self._pending_operations.append(t)
+
+ def pop_pending_operation(self) -> Any:
+ return self._pending_operations.pop()
+
@property
def is_batch(self) -> bool:
"""Return whether object is part of batch or not."""
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 5911e218ee8..9ed08aa6cdf 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -23,8 +23,8 @@
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
from monai.utils import look_up_option
-from monai.utils.enums import MetaKeys, PostFix, SpaceKeys
-from monai.utils.type_conversion import convert_data_type, convert_to_tensor
+from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
+from monai.utils.type_conversion import convert_data_type, convert_to_numpy, convert_to_tensor
__all__ = ["MetaTensor"]
@@ -43,7 +43,7 @@ class MetaTensor(MetaObj, torch.Tensor):
* For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the
first instance of `MetaTensor` if `a.is_batch` is False
- (For batched data, the metdata will be shallow copied for efficiency purposes).
+ (For batched data, the metadata will be shallow copied for efficiency purposes).
Example:
.. code-block:: python
@@ -312,7 +312,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
except AttributeError:
return NotImplemented
- def get_default_affine(self, dtype=torch.float64) -> torch.Tensor:
+ @staticmethod
+ def get_default_affine(dtype=torch.float64) -> torch.Tensor:
return torch.eye(4, device=torch.device("cpu"), dtype=dtype)
def as_tensor(self) -> torch.Tensor:
@@ -320,7 +321,7 @@ def as_tensor(self) -> torch.Tensor:
Return the `MetaTensor` as a `torch.Tensor`.
It is OS dependent as to whether this will be a deep copy or not.
"""
- return self.as_subclass(torch.Tensor) # type: ignore
+ return self.as_subclass(torch.Tensor)
def get_array(self, output_type=np.ndarray, dtype=None, device=None, *_args, **_kwargs):
"""
@@ -444,6 +445,20 @@ def pixdim(self):
return [affine_to_spacing(a) for a in self.affine]
return affine_to_spacing(self.affine)
+ def peek_pending_shape(self):
+ """Get the currently expected spatial shape as if all the pending operations are executed."""
+ res = None
+ if self.pending_operations:
+ res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)
+ # default to spatial shape (assuming channel-first input)
+ return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res
+
+ def peek_pending_affine(self):
+ res = None
+ if self.pending_operations:
+ res = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
+ return self.affine if res is None else res
+
def new_empty(self, size, dtype=None, device=None, requires_grad=False):
"""
must be defined for deepcopy to work
@@ -503,4 +518,16 @@ def ensure_torch_and_prune_meta(
return MetaTensor(img, meta=meta)
def __repr__(self, *, tensor_contents=None):
+ """
+ Prints out a long representation of the MetaTensor object with metadata as well as content data.
+
+ Args:
+ tensor_contents: currently unused
+ """
return self.as_tensor().__repr__() + super().__repr__()
+
+ def __str__(self):
+ """
+ Prints a simpler representation of the tensor identical to torch.Tensor.__str__.
+ """
+ return str(self.as_tensor())
diff --git a/monai/data/synthetic.py b/monai/data/synthetic.py
index 46d555cf114..7f51b687fbe 100644
--- a/monai/data/synthetic.py
+++ b/monai/data/synthetic.py
@@ -48,6 +48,9 @@ def create_test_image_2d(
channel_dim: if None, create an image without channel dimension, otherwise create
an image with channel dimension as first dim or last dim. Defaults to `None`.
random_state: the random generator to use. Defaults to `np.random`.
+
+ Returns:
+ Randomised Numpy array with shape (`width`, `height`)
"""
if rad_max <= rad_min:
@@ -120,6 +123,9 @@ def create_test_image_3d(
an image with channel dimension as first dim or last dim. Defaults to `None`.
random_state: the random generator to use. Defaults to `np.random`.
+ Returns:
+ Randomised Numpy array with shape (`width`, `height`, `depth`)
+
See also:
:py:meth:`~create_test_image_2d`
"""
diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py
index b882c2533c1..0a9079ae0c5 100644
--- a/monai/data/thread_buffer.py
+++ b/monai/data/thread_buffer.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from multiprocessing.context import SpawnContext
from queue import Empty, Full, Queue
from threading import Thread
@@ -83,7 +82,7 @@ def __iter__(self):
def buffer_iterator(src, buffer_size: int = 1, timeout: float = 0.01, repeats: int = 1):
"""
Create a ThreadBuffer object using the `src`, `buffer_size`, and `timeout` parameters given for the constructor
- aguments of the same names, and yield each generated object `repeats` number of times successively.
+ arguments of the same names, and yield each generated object `repeats` number of times successively.
Args:
src: Source data iterable
diff --git a/monai/data/utils.py b/monai/data/utils.py
index cf426d4c3e5..0fb5c9a33a3 100644
--- a/monai/data/utils.py
+++ b/monai/data/utils.py
@@ -30,7 +30,6 @@
from monai import config
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
from monai.data.meta_obj import MetaObj
-from monai.networks.layers.simplelayers import GaussianFilter
from monai.utils import (
MAX_SEED,
BlendMode,
@@ -53,7 +52,6 @@
DataFrame, _ = optional_import("pandas", name="DataFrame")
nib, _ = optional_import("nibabel")
-
__all__ = [
"AFFINE_TOL",
"SUPPORTED_PICKLE_MOD",
@@ -670,6 +668,11 @@ def set_rnd(obj, seed: int) -> int:
obj: object to set seed or random state for.
seed: set the random state with an integer seed.
"""
+ if isinstance(obj, (tuple, list)): # ZipDataset.data is a list
+ _seed = seed
+ for item in obj:
+ _seed = set_rnd(item, seed=seed)
+ return seed if _seed == seed else seed + 1 # return a different seed if there are randomizable items
if not hasattr(obj, "__dict__"):
return seed # no attribute
if hasattr(obj, "set_random_state"):
@@ -690,7 +693,7 @@ def affine_to_spacing(affine: NdarrayTensor, r: int = 3, dtype=float, suppress_z
affine: a d x d affine matrix.
r: indexing based on the spatial rank, spacing is computed from `affine[:r, :r]`.
dtype: data type of the output.
- suppress_zeros: whether to surpress the zeros with ones.
+ suppress_zeros: whether to suppress the zeros with ones.
Returns:
an `r` dimensional vector of spacing.
@@ -822,7 +825,10 @@ def zoom_affine(affine: np.ndarray, scale: Union[np.ndarray, Sequence[float]], d
def compute_shape_offset(
- spatial_shape: Union[np.ndarray, Sequence[int]], in_affine: NdarrayOrTensor, out_affine: NdarrayOrTensor
+ spatial_shape: Union[np.ndarray, Sequence[int]],
+ in_affine: NdarrayOrTensor,
+ out_affine: NdarrayOrTensor,
+ scale_extent: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Given input and output affine, compute appropriate shapes
@@ -834,12 +840,29 @@ def compute_shape_offset(
spatial_shape: input array's shape
in_affine (matrix): 2D affine matrix
out_affine (matrix): 2D affine matrix
+ scale_extent: whether the scale is computed based on the spacing or the full extent of voxels, for example, for
+ a factor of 0.5 scaling:
+
+ option 1, "o" represents a voxel, scaling the distance between voxels::
+
+ o--o--o
+ o-----o
+
+ option 2, each voxel has a physical extent, scaling the full voxel extent::
+
+ | voxel 1 | voxel 2 | voxel 3 | voxel 4 |
+ | voxel 1 | voxel 2 |
+
+ Option 1 may reduce the number of locations that requiring interpolation. Option 2 is more resolution
+ agnostic, that is, resampling coordinates depend on the scaling factor, not on the number of voxels.
+ Default is False, using option 1 to compute the shape and offset.
+
"""
shape = np.array(spatial_shape, copy=True, dtype=float)
sr = len(shape)
in_affine_ = convert_data_type(to_affine_nd(sr, in_affine), np.ndarray)[0]
out_affine_ = convert_data_type(to_affine_nd(sr, out_affine), np.ndarray)[0]
- in_coords = [(0.0, dim - 1.0) for dim in shape]
+ in_coords = [(-0.5, dim - 0.5) if scale_extent else (0.0, dim - 1.0) for dim in shape]
corners: np.ndarray = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1))
corners = np.concatenate((corners, np.ones_like(corners[:1])))
corners = in_affine_ @ corners
@@ -849,16 +872,20 @@ def compute_shape_offset(
raise ValueError(f"Affine {out_affine_} is not invertible") from e
corners_out = inv_mat @ corners
corners_out = corners_out[:-1] / corners_out[-1]
- out_shape = np.round(corners_out.ptp(axis=1) + 1.0)
- mat = inv_mat[:-1, :-1]
- k = 0
+ out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0)
+ all_dist = inv_mat[:-1, :-1] @ corners[:-1, :]
+ offset = None
for i in range(corners.shape[1]):
- min_corner = np.min(mat @ corners[:-1, :] - mat @ corners[:-1, i : i + 1], 1)
+ min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)
if np.allclose(min_corner, 0.0, rtol=AFFINE_TOL):
- k = i
+ offset = corners[:-1, i] # corner is the smallest, shift the corner to origin
break
- offset = corners[:-1, k]
- return out_shape.astype(int, copy=False), offset
+ if offset is None: # otherwise make output image center aligned with the input image center
+ offset = in_affine_[:-1, :-1] @ (shape / 2.0) + in_affine_[:-1, -1] - out_affine_[:-1, :-1] @ (out_shape / 2.0)
+ if scale_extent:
+ in_offset = np.append(0.5 * (shape / out_shape - 1.0), 1.0)
+ offset = np.abs((in_affine_ @ in_offset / in_offset[-1])[:-1]) * np.sign(offset)
+ return out_shape.astype(int, copy=False), offset # type: ignore
def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.float64) -> NdarrayTensor:
@@ -1038,17 +1065,16 @@ def compute_importance_map(
if mode == BlendMode.CONSTANT:
importance_map = torch.ones(patch_size, device=device, dtype=torch.float)
elif mode == BlendMode.GAUSSIAN:
- center_coords = [i // 2 for i in patch_size]
+
sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size))
sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)]
- importance_map = torch.zeros(patch_size, device=device)
- importance_map[tuple(center_coords)] = 1
- pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(device=device, dtype=torch.float)
- importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0))
- importance_map = importance_map.squeeze(0).squeeze(0)
- importance_map = importance_map / torch.max(importance_map)
- importance_map = importance_map.float()
+ for i in range(len(patch_size)):
+ x = torch.arange(
+ start=-(patch_size[i] - 1) / 2.0, end=(patch_size[i] - 1) / 2.0 + 1, dtype=torch.float, device=device
+ )
+ x = torch.exp(x**2 / (-2 * sigmas[i] ** 2)) # 1D gaussian
+ importance_map = importance_map.unsqueeze(-1) * x[(None,) * i] if i > 0 else x
else:
raise ValueError(
f"Unsupported mode: {mode}, available options are [{BlendMode.CONSTANT}, {BlendMode.CONSTANT}]."
diff --git a/monai/data/video_dataset.py b/monai/data/video_dataset.py
new file mode 100644
index 00000000000..cbb9e0efe36
--- /dev/null
+++ b/monai/data/video_dataset.py
@@ -0,0 +1,238 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tempfile
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
+
+import numpy as np
+from torch.utils.data import Dataset, IterableDataset
+
+from monai.utils.enums import ColorOrder
+from monai.utils.module import optional_import
+
+if TYPE_CHECKING:
+ import cv2
+
+ has_cv2 = True
+else:
+ cv2, has_cv2 = optional_import("cv2")
+
+__all__ = ["VideoDataset", "VideoFileDataset", "CameraDataset"]
+
+
+class SuppressStderr:
+ """Suppress stderr. Useful as OpenCV (and dependencies) can produce a lot of output."""
+
+ def __enter__(self):
+ self.errnull_file = open(os.devnull, "w")
+ self.old_stderr_fileno_undup = sys.stderr.fileno()
+ self.old_stderr_fileno = os.dup(sys.stderr.fileno())
+ self.old_stderr = sys.stderr
+ os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+ sys.stderr = self.errnull_file
+ return self
+
+ def __exit__(self, *_):
+ sys.stderr = self.old_stderr
+ os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
+ os.close(self.old_stderr_fileno)
+ self.errnull_file.close()
+
+
+class VideoDataset:
+ def __init__(
+ self,
+ video_source: Union[str, int],
+ transform: Optional[Callable] = None,
+ max_num_frames: Optional[int] = None,
+ color_order: str = ColorOrder.RGB,
+ multiprocessing: bool = False,
+ channel_dim: int = 0,
+ ) -> None:
+ """
+ Base video dataset.
+
+ Args:
+ video_source: filename of video.
+ transform: transform to be applied to each frame.
+ max_num_frames: Max number of frames to iterate across. If `None` is passed,
+ then the dataset will iterate until the end of the file.
+ color_order: Color order to return frame. Default is RGB.
+ multiprocessing: If `True`, open the video source on the fly. This makes
+ things process-safe, which is useful when combined with a DataLoader
+ with `num_workers>0`. However, when using with `num_workers==0`, it
+ makes sense to use `multiprocessing=False`, as the source will then
+ only be opened once, at construction, which will be faster in those
+ circumstances.
+ channel_dim: OpenCV reads with the channel as the last dimension. Use this
+ flag to move it elsewhere. By default this is zero, so the channel
+ dimension is moved to the front.
+
+ Raises:
+ RuntimeError: OpenCV not installed.
+ NotImplementedError: Unknown color order.
+ """
+ if not has_cv2:
+ raise RuntimeError("OpenCV not installed.")
+ if color_order not in ColorOrder:
+ raise NotImplementedError
+
+ self.color_order = color_order
+ self.channel_dim = channel_dim
+ self.video_source = video_source
+ self.multiprocessing = multiprocessing
+ if not multiprocessing:
+ self.cap = self.open_video(video_source)
+ self.transform = transform
+ self.max_num_frames = max_num_frames
+
+ @staticmethod
+ def open_video(video_source: Union[str, int]):
+ """
+ Use OpenCV to open a video source from either file or capture device.
+
+ Args:
+ video_source: filename or index referring to capture device.
+
+ Raises:
+ RuntimeError: Source is a file but file not found.
+ RuntimeError: Failed to open source.
+ """
+ if isinstance(video_source, str) and not os.path.isfile(video_source):
+ raise RuntimeError("Video file does not exist: " + video_source)
+ with SuppressStderr():
+ cap = cv2.VideoCapture(video_source)
+ if not cap.isOpened():
+ raise RuntimeError(f"Failed to open video: {video_source}")
+ return cap
+
+ def _get_cap(self):
+ """Return the cap. If multiprocessing, create a new one. Else return the one from construction time."""
+ return self.open_video(self.video_source) if self.multiprocessing else self.cap
+
+ def get_fps(self) -> int:
+ """Get the FPS of the capture device."""
+ return self._get_cap().get(cv2.CAP_PROP_FPS) # type: ignore
+
+ def get_frame(self) -> Any:
+ """Get next frame. For a file, this will be the next frame, whereas for a camera
+ source, it will be the next available frame."""
+ ret, frame = self._get_cap().read()
+ if not ret:
+ raise RuntimeError("Failed to read frame.")
+ # Switch color order if desired
+ if self.color_order == ColorOrder.RGB:
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ # move channel dim
+ frame = np.moveaxis(frame, -1, self.channel_dim)
+ return self.transform(frame) if self.transform is not None else frame
+
+
+class VideoFileDataset(Dataset, VideoDataset):
+ """
+ Video dataset from file.
+
+ This class requires that OpenCV be installed.
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ VideoDataset.__init__(self, *args, **kwargs)
+ num_frames = self.get_num_frames()
+ if self.max_num_frames is None or num_frames < self.max_num_frames:
+ self.max_num_frames = num_frames
+
+ @staticmethod
+ def get_available_codecs() -> Dict[str, str]:
+ """Try different codecs, see which are available.
+ Returns a dictionary with of available codecs with codecs as keys and file extensions as values."""
+ if not has_cv2:
+ return {}
+ all_codecs = {"mp4v": ".mp4", "X264": ".avi", "H264": ".mp4", "MP42": ".mp4", "MJPG": ".mjpeg", "DIVX": ".avi"}
+ codecs = {}
+ with SuppressStderr():
+ writer = cv2.VideoWriter()
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ for codec, ext in all_codecs.items():
+ fname = os.path.join(tmp_dir, f"test{ext}")
+ fourcc = cv2.VideoWriter_fourcc(*codec)
+ noviderr = writer.open(fname, fourcc, 1, (10, 10))
+ if noviderr:
+ codecs[codec] = ext
+ writer.release()
+ return codecs
+
+ def get_num_frames(self) -> int:
+ """
+ Return the number of frames in a video file.
+
+ Raises:
+ RuntimeError: no frames found.
+ """
+ num_frames = int(self._get_cap().get(cv2.CAP_PROP_FRAME_COUNT))
+ if num_frames == 0:
+ raise RuntimeError("0 frames found")
+ return num_frames
+
+ def __len__(self):
+ return self.max_num_frames
+
+ def __getitem__(self, index: int) -> Any:
+ """
+ Fetch single data item from index.
+ """
+ if self.max_num_frames is not None and index >= self.max_num_frames:
+ raise IndexError
+ self._get_cap().set(cv2.CAP_PROP_POS_FRAMES, index)
+ return self.get_frame()
+
+
+class CameraDataset(IterableDataset, VideoDataset):
+ """
+ Video dataset from a capture device (e.g., webcam).
+
+ This class requires that OpenCV be installed.
+
+ Args:
+ video_source: index of capture device.
+ `get_num_devices` can be used to determine possible devices.
+ transform: transform to be applied to each frame.
+ max_num_frames: Max number of frames to iterate across. If `None` is passed,
+ then the dataset will iterate infinitely.
+
+ Raises:
+ RuntimeError: OpenCV not installed.
+ """
+
+ @staticmethod
+ def get_num_devices() -> int:
+ """Get number of possible devices detected by OpenCV that can be used for capture."""
+ if not has_cv2:
+ return 0
+ num_devices = 0
+ while True:
+ cap = cv2.VideoCapture(num_devices)
+ if not cap.read()[0]:
+ break
+ num_devices += 1
+ cap.release()
+ return num_devices
+
+ def __iter__(self):
+ frame_count = 0
+ while True:
+ frame = self.get_frame()
+ frame_count += 1
+ yield frame
+ if self.max_num_frames is not None:
+ if frame_count == self.max_num_frames:
+ break
diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py
index dc65bc8187e..d4b70f7f0a0 100644
--- a/monai/data/wsi_datasets.py
+++ b/monai/data/wsi_datasets.py
@@ -14,13 +14,15 @@
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import numpy as np
+import torch
from monai.data import Dataset
+from monai.data.meta_tensor import MetaTensor
from monai.data.utils import iter_patch_position
from monai.data.wsi_reader import BaseWSIReader, WSIReader
from monai.transforms import ForegroundMask, Randomizable, apply_transform
-from monai.utils import CommonKeys, ProbMapKeys, convert_to_dst_type, ensure_tuple_rep
-from monai.utils.enums import WSIPatchKeys
+from monai.utils import convert_to_dst_type, ensure_tuple_rep
+from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys
__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"]
@@ -42,10 +44,14 @@ class PatchWSIDataset(Dataset):
- a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
- a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
- - an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
+ - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class
+ Returns:
+ dict: a dictionary of loaded image (in MetaTensor format) along with the labels (if requested).
+ {"image": MetaTensor, "label": torch.Tensor}
+
Note:
The input data has the following form as an example:
@@ -110,7 +116,7 @@ def _get_wsi_object(self, sample: Dict):
return self.wsi_object_dict[image_path]
def _get_label(self, sample: Dict):
- return np.array(sample[CommonKeys.LABEL], dtype=np.float32)
+ return torch.tensor(sample[CommonKeys.LABEL], dtype=torch.float32)
def _get_location(self, sample: Dict):
if self.center_location:
@@ -145,15 +151,18 @@ def _transform(self, index: int):
# Extract patch image and associated metadata
image, metadata = self._get_data(sample)
- output = {CommonKeys.IMAGE: image, CommonKeys.METADATA: metadata}
+
+ # Add additional metadata from sample
+ for key in self.additional_meta_keys:
+ metadata[key] = sample[key]
+
+ # Create MetaTensor output for image
+ output = {CommonKeys.IMAGE: MetaTensor(image, meta=metadata)}
# Include label in the output
if self.include_label:
output[CommonKeys.LABEL] = self._get_label(sample)
- for key in self.additional_meta_keys:
- metadata[key] = sample[key]
-
# Apply transforms and return it
return apply_transform(self.transform, output) if self.transform else output
@@ -177,7 +186,7 @@ class SlidingPatchWSIDataset(Randomizable, PatchWSIDataset):
- a string, it defines the backend of `monai.data.WSIReader`.
- a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader,
- - an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
+ - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
map_level: the resolution level at which the output map is created.
seed: random seed to randomly generate offsets. Defaults to 0.
@@ -322,7 +331,7 @@ class MaskedPatchWSIDataset(PatchWSIDataset):
- a string, it defines the backend of `monai.data.WSIReader`.
- a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader,
- - an instance of a a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
+ - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class
@@ -333,7 +342,7 @@ class MaskedPatchWSIDataset(PatchWSIDataset):
[
{"image": "path/to/image1.tiff"},
- {"image": "path/to/image2.tiff", "patch_size": [20, 20], "patch_level": 2}
+ {"image": "path/to/image2.tiff", "size": [20, 20], "level": 2}
]
"""
diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py
index 0bb2de987c3..0d3924182c0 100644
--- a/monai/data/wsi_reader.py
+++ b/monai/data/wsi_reader.py
@@ -22,8 +22,9 @@
CuImage, _ = optional_import("cucim", name="CuImage")
OpenSlide, _ = optional_import("openslide", name="OpenSlide")
+TiffFile, _ = optional_import("tifffile", name="TiffFile")
-__all__ = ["BaseWSIReader", "WSIReader", "CuCIMWSIReader", "OpenSlideWSIReader"]
+__all__ = ["BaseWSIReader", "WSIReader", "CuCIMWSIReader", "OpenSlideWSIReader", "TiffFileWSIReader"]
class BaseWSIReader(ImageReader):
@@ -46,8 +47,8 @@ class BaseWSIReader(ImageReader):
- `read` reads a whole slide image object from a given file
- `get_size` returns the size of the whole slide image of a given wsi object at a given level.
- `get_level_count` returns the number of levels in the whole slide image
- - `get_patch` extracts and returns a patch image form the whole slide image
- - `get_metadata` extracts and returns metadata for a whole slide image and a specific patch.
+ - `_get_patch` extracts and returns a patch image form the whole slide image
+ - `_get_metadata` extracts and returns metadata for a whole slide image and a specific patch.
"""
@@ -55,7 +56,7 @@ class BaseWSIReader(ImageReader):
supported_suffixes: List[str] = []
backend = ""
- def __init__(self, level: int, channel_dim: int = 0, **kwargs):
+ def __init__(self, level: int = 0, channel_dim: int = 0, **kwargs):
super().__init__()
self.level = level
self.channel_dim = channel_dim
@@ -63,7 +64,7 @@ def __init__(self, level: int, channel_dim: int = 0, **kwargs):
self.metadata: Dict[Any, Any] = {}
@abstractmethod
- def get_size(self, wsi, level: int) -> Tuple[int, int]:
+ def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]:
"""
Returns the size (height, width) of the whole slide image at a given level.
@@ -86,13 +87,14 @@ def get_level_count(self, wsi) -> int:
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
@abstractmethod
- def get_downsample_ratio(self, wsi, level: int) -> float:
+ def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float:
"""
Returns the down-sampling ratio of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
- level: the level number where the size is calculated
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
@@ -103,7 +105,19 @@ def get_file_path(self, wsi) -> str:
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
@abstractmethod
- def get_patch(
+ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]:
+ """
+ Returns the micro-per-pixel resolution of the whole slide image at a given level.
+
+ Args:
+ wsi: a whole slide image object loaded from a file
+ level: the level number where the size is calculated
+
+ """
+ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
+
+ @abstractmethod
+ def _get_patch(
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
) -> np.ndarray:
"""
@@ -121,7 +135,7 @@ def get_patch(
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
- def get_metadata(
+ def _get_metadata(
self, wsi, patch: np.ndarray, location: Tuple[int, int], size: Tuple[int, int], level: int
) -> Dict:
"""
@@ -137,13 +151,15 @@ def get_metadata(
"""
if self.channel_dim >= len(patch.shape) or self.channel_dim < -len(patch.shape):
- ValueError(f"The desired channel_dim ({self.channel_dim}) is out of bound for image shape: {patch.shape}")
+ raise ValueError(
+ f"The desired channel_dim ({self.channel_dim}) is out of bound for image shape: {patch.shape}"
+ )
channel_dim: int = self.channel_dim + (len(patch.shape) if self.channel_dim < 0 else 0)
metadata: Dict = {
"backend": self.backend,
"original_channel_dim": channel_dim,
"spatial_shape": np.array(patch.shape[:channel_dim] + patch.shape[channel_dim + 1 :]),
- "num_patches": 1,
+ WSIPatchKeys.COUNT.value: 1,
WSIPatchKeys.PATH.value: self.get_file_path(wsi),
WSIPatchKeys.LOCATION.value: np.asarray(location),
WSIPatchKeys.SIZE.value: np.asarray(size),
@@ -206,7 +222,7 @@ def get_data(
raise ValueError(f"Patch size should be greater than zero, provided: patch size = {size}")
# Extract a patch or the entire image
- patch = self.get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype, mode=mode)
+ patch = self._get_patch(each_wsi, location=location, size=size, level=level, dtype=dtype, mode=mode)
# check if the image has three dimensions (2D + color)
if patch.ndim != 3:
@@ -228,7 +244,7 @@ def get_data(
f"{patch.shape[self.channel_dim]}. "
)
# Get patch-related metadata
- metadata: dict = self.get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level)
+ metadata: dict = self._get_metadata(wsi=each_wsi, patch=patch, location=location, size=size, level=level)
# Create a list of patches and metadata
patch_list.append(patch)
metadata_list.append(metadata)
@@ -262,6 +278,7 @@ class WSIReader(BaseWSIReader):
backend: the name of backend whole slide image reader library, the default is cuCIM.
level: the level at which patches are extracted.
channel_dim: the desired dimension for color channel. Default to 0 (channel first).
+ num_workers: number of workers for multi-thread image loading (cucim backend only).
kwargs: additional arguments to be passed to the backend library
"""
@@ -269,13 +286,17 @@ class WSIReader(BaseWSIReader):
def __init__(self, backend="cucim", level: int = 0, channel_dim: int = 0, **kwargs):
super().__init__(level, channel_dim, **kwargs)
self.backend = backend.lower()
- self.reader: Union[CuCIMWSIReader, OpenSlideWSIReader]
+ self.reader: Union[CuCIMWSIReader, OpenSlideWSIReader, TiffFileWSIReader]
if self.backend == "cucim":
self.reader = CuCIMWSIReader(level=level, channel_dim=channel_dim, **kwargs)
elif self.backend == "openslide":
self.reader = OpenSlideWSIReader(level=level, channel_dim=channel_dim, **kwargs)
+ elif self.backend == "tifffile":
+ self.reader = TiffFileWSIReader(level=level, channel_dim=channel_dim, **kwargs)
else:
- raise ValueError(f"The supported backends are cucim and openslide, '{self.backend}' was given.")
+ raise ValueError(
+ f"The supported backends are cucim, openslide, and tifffile but '{self.backend}' was given."
+ )
self.supported_suffixes = self.reader.supported_suffixes
def get_level_count(self, wsi) -> int:
@@ -288,33 +309,56 @@ def get_level_count(self, wsi) -> int:
"""
return self.reader.get_level_count(wsi)
- def get_size(self, wsi, level: int) -> Tuple[int, int]:
+ def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]:
"""
Returns the size (height, width) of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
- level: the level number where the size is calculated
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
"""
+ if level is None:
+ level = self.level
+
return self.reader.get_size(wsi, level)
- def get_downsample_ratio(self, wsi, level: int) -> float:
+ def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float:
"""
Returns the down-sampling ratio of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
- level: the level number where the size is calculated
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
"""
+ if level is None:
+ level = self.level
+
return self.reader.get_downsample_ratio(wsi, level)
def get_file_path(self, wsi) -> str:
"""Return the file path for the WSI object"""
return self.reader.get_file_path(wsi)
- def get_patch(
+ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]:
+ """
+ Returns the micro-per-pixel resolution of the whole slide image at a given level.
+
+ Args:
+ wsi: a whole slide image object loaded from a file
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
+
+ """
+ if level is None:
+ level = self.level
+
+ return self.reader.get_mpp(wsi, level)
+
+ def _get_patch(
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
) -> np.ndarray:
"""
@@ -330,7 +374,7 @@ def get_patch(
mode: the output image mode, 'RGB' or 'RGBA'
"""
- return self.reader.get_patch(wsi=wsi, location=location, size=size, level=level, dtype=dtype, mode=mode)
+ return self.reader._get_patch(wsi=wsi, location=location, size=size, level=level, dtype=dtype, mode=mode)
def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
"""
@@ -356,6 +400,7 @@ class CuCIMWSIReader(BaseWSIReader):
level: the whole slide image level at which the image is extracted. (default=0)
This is overridden if the level argument is provided in `get_data`.
channel_dim: the desired dimension for color channel. Default to 0 (channel first).
+ num_workers: number of workers for multi-thread image loading
kwargs: additional args for `cucim.CuImage` module:
https://github.com/rapidsai/cucim/blob/main/cpp/include/cucim/cuimage.h
@@ -364,8 +409,9 @@ class CuCIMWSIReader(BaseWSIReader):
supported_suffixes = ["tif", "tiff", "svs"]
backend = "cucim"
- def __init__(self, level: int = 0, channel_dim: int = 0, **kwargs):
+ def __init__(self, level: int = 0, channel_dim: int = 0, num_workers: int = 0, **kwargs):
super().__init__(level, channel_dim, **kwargs)
+ self.num_workers = num_workers
@staticmethod
def get_level_count(wsi) -> int:
@@ -378,34 +424,57 @@ def get_level_count(wsi) -> int:
"""
return wsi.resolutions["level_count"] # type: ignore
- @staticmethod
- def get_size(wsi, level: int) -> Tuple[int, int]:
+ def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]:
"""
Returns the size (height, width) of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
- level: the level number where the size is calculated
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
"""
+ if level is None:
+ level = self.level
+
return (wsi.resolutions["level_dimensions"][level][1], wsi.resolutions["level_dimensions"][level][0])
- @staticmethod
- def get_downsample_ratio(wsi, level: int) -> float:
+ def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float:
"""
Returns the down-sampling ratio of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
- level: the level number where the size is calculated
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
"""
+ if level is None:
+ level = self.level
+
return wsi.resolutions["level_downsamples"][level] # type: ignore
- def get_file_path(self, wsi) -> str:
+ @staticmethod
+ def get_file_path(wsi) -> str:
"""Return the file path for the WSI object"""
return str(abspath(wsi.path))
+ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]:
+ """
+ Returns the micro-per-pixel resolution of the whole slide image at a given level.
+
+ Args:
+ wsi: a whole slide image object loaded from a file
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
+
+ """
+ if level is None:
+ level = self.level
+
+ factor = float(wsi.resolutions["level_downsamples"][level])
+ return (wsi.metadata["cucim"]["spacing"][1] * factor, wsi.metadata["cucim"]["spacing"][0] * factor)
+
def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
"""
Read whole slide image objects from given file or list of files.
@@ -430,7 +499,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
return wsi_list if len(filenames) > 1 else wsi_list[0]
- def get_patch(
+ def _get_patch(
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
) -> np.ndarray:
"""
@@ -448,7 +517,9 @@ def get_patch(
"""
# Extract a patch or the entire image
# (reverse the order of location and size to become WxH for cuCIM)
- patch: np.ndarray = wsi.read_region(location=location[::-1], size=size[::-1], level=level)
+ patch: np.ndarray = wsi.read_region(
+ location=location[::-1], size=size[::-1], level=level, num_workers=self.num_workers
+ )
# Convert to numpy
patch = np.asarray(patch, dtype=dtype)
@@ -484,9 +555,6 @@ class OpenSlideWSIReader(BaseWSIReader):
supported_suffixes = ["tif", "tiff", "svs"]
backend = "openslide"
- def __init__(self, level: int = 0, channel_dim: int = 0, **kwargs):
- super().__init__(level, channel_dim, **kwargs)
-
@staticmethod
def get_level_count(wsi) -> int:
"""
@@ -498,34 +566,68 @@ def get_level_count(wsi) -> int:
"""
return wsi.level_count # type: ignore
- @staticmethod
- def get_size(wsi, level: int) -> Tuple[int, int]:
+ def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]:
"""
Returns the size (height, width) of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
- level: the level number where the size is calculated
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
"""
+ if level is None:
+ level = self.level
+
return (wsi.level_dimensions[level][1], wsi.level_dimensions[level][0])
- @staticmethod
- def get_downsample_ratio(wsi, level: int) -> float:
+ def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float:
"""
Returns the down-sampling ratio of the whole slide image at a given level.
Args:
wsi: a whole slide image object loaded from a file
- level: the level number where the size is calculated
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
"""
+ if level is None:
+ level = self.level
+
return wsi.level_downsamples[level] # type: ignore
- def get_file_path(self, wsi) -> str:
+ @staticmethod
+ def get_file_path(wsi) -> str:
"""Return the file path for the WSI object"""
return str(abspath(wsi._filename))
+ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]:
+ """
+ Returns the micro-per-pixel resolution of the whole slide image at a given level.
+
+ Args:
+ wsi: a whole slide image object loaded from a file
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
+
+ """
+ if level is None:
+ level = self.level
+ unit = wsi.properties["tiff.ResolutionUnit"]
+ if unit == "centimeter":
+ factor = 10000.0
+ elif unit == "milimeter":
+ factor = 1000.0
+ elif unit == "micrometer":
+ factor = 1.0
+ elif unit == "inch":
+ factor = 25400.0
+ else:
+ raise ValueError(f"The resolution unit is not a valid tiff resolution: {unit}")
+
+ factor *= wsi.level_downsamples[level]
+ return (factor / float(wsi.properties["tiff.YResolution"]), factor / float(wsi.properties["tiff.XResolution"]))
+
def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
"""
Read whole slide image objects from given file or list of files.
@@ -549,7 +651,7 @@ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
return wsi_list if len(filenames) > 1 else wsi_list[0]
- def get_patch(
+ def _get_patch(
self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
) -> np.ndarray:
"""
@@ -579,3 +681,159 @@ def get_patch(
patch = np.moveaxis(patch, -1, self.channel_dim)
return patch
+
+
+@require_pkg(pkg_name="tifffile")
+class TiffFileWSIReader(BaseWSIReader):
+ """
+ Read whole slide images and extract patches using TiffFile library.
+
+ Args:
+ level: the whole slide image level at which the image is extracted. (default=0)
+ This is overridden if the level argument is provided in `get_data`.
+ channel_dim: the desired dimension for color channel. Default to 0 (channel first).
+ kwargs: additional args for `tifffile.TiffFile` module.
+
+ """
+
+ supported_suffixes = ["tif", "tiff", "svs"]
+ backend = "tifffile"
+
+ @staticmethod
+ def get_level_count(wsi) -> int:
+ """
+ Returns the number of levels in the whole slide image.
+
+ Args:
+ wsi: a whole slide image object loaded from a file
+
+ """
+ return len(wsi.pages)
+
+ def get_size(self, wsi, level: Optional[int] = None) -> Tuple[int, int]:
+ """
+ Returns the size (height, width) of the whole slide image at a given level.
+
+ Args:
+ wsi: a whole slide image object loaded from a file
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
+
+ """
+ if level is None:
+ level = self.level
+
+ return (wsi.pages[level].imagelength, wsi.pages[level].imagewidth)
+
+ def get_downsample_ratio(self, wsi, level: Optional[int] = None) -> float:
+ """
+ Returns the down-sampling ratio of the whole slide image at a given level.
+
+ Args:
+ wsi: a whole slide image object loaded from a file
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
+
+ """
+ if level is None:
+ level = self.level
+
+ return float(wsi.pages[0].imagelength) / float(wsi.pages[level].imagelength)
+
+ @staticmethod
+ def get_file_path(wsi) -> str:
+ """Return the file path for the WSI object"""
+ return str(abspath(wsi.filehandle.path))
+
+ def get_mpp(self, wsi, level: Optional[int] = None) -> Tuple[float, float]:
+ """
+ Returns the micro-per-pixel resolution of the whole slide image at a given level.
+
+ Args:
+ wsi: a whole slide image object loaded from a file
+ level: the level number where the size is calculated. If not provided the default level (from `self.level`)
+ will be used.
+
+ """
+ if level is None:
+ level = self.level
+
+ unit = wsi.pages[level].tags["ResolutionUnit"].value
+ if unit == unit.CENTIMETER:
+ factor = 10000.0
+ elif unit == unit.MILLIMETER:
+ factor = 1000.0
+ elif unit == unit.MICROMETER:
+ factor = 1.0
+ elif unit == unit.INCH:
+ factor = 25400.0
+ else:
+ raise ValueError(f"The resolution unit is not a valid tiff resolution or missing: {unit.name}")
+
+ # Here x and y resolutions are rational numbers so each of them is represented by a tuple.
+ yres = wsi.pages[level].tags["YResolution"].value
+ xres = wsi.pages[level].tags["XResolution"].value
+ return (factor * yres[1] / yres[0], factor * xres[1] / xres[0])
+
+ def read(self, data: Union[Sequence[PathLike], PathLike, np.ndarray], **kwargs):
+ """
+ Read whole slide image objects from given file or list of files.
+
+ Args:
+ data: file name or a list of file names to read.
+ kwargs: additional args that overrides `self.kwargs` for existing keys.
+
+ Returns:
+ whole slide image object or list of such objects
+
+ """
+ wsi_list: List = []
+
+ filenames: Sequence[PathLike] = ensure_tuple(data)
+ kwargs_ = self.kwargs.copy()
+ kwargs_.update(kwargs)
+ for filename in filenames:
+ wsi = TiffFile(filename, **kwargs_)
+ wsi_list.append(wsi)
+
+ return wsi_list if len(filenames) > 1 else wsi_list[0]
+
+ def _get_patch(
+ self, wsi, location: Tuple[int, int], size: Tuple[int, int], level: int, dtype: DtypeLike, mode: str
+ ) -> np.ndarray:
+ """
+ Extracts and returns a patch image form the whole slide image.
+
+ Args:
+ wsi: a whole slide image object loaded from a file or a lis of such objects
+ location: (top, left) tuple giving the top left pixel in the level 0 reference frame. Defaults to (0, 0).
+ size: (height, width) tuple giving the patch size at the given level (`level`).
+ If None, it is set to the full image size at the given level.
+ level: the level number. Defaults to 0
+ dtype: the data type of output image
+ mode: the output image mode, 'RGB' or 'RGBA'
+
+ """
+ # Load the entire image
+ wsi_image: np.ndarray = wsi.asarray(level=level).astype(dtype)
+ if len(wsi_image.shape) < 3:
+ wsi_image = wsi_image[..., None]
+
+ # Extract patch
+ downsampling_ratio = self.get_downsample_ratio(wsi=wsi, level=level)
+ location_ = [round(location[i] / downsampling_ratio) for i in range(len(location))]
+ patch = wsi_image[location_[0] : location_[0] + size[0], location_[1] : location_[1] + size[1], :].copy()
+
+ # Make the channel to desired dimensions
+ patch = np.moveaxis(patch, -1, self.channel_dim)
+
+ # Check if the color channel is 3 (RGB) or 4 (RGBA)
+ if mode in "RGB":
+ if patch.shape[self.channel_dim] not in [3, 4]:
+ raise ValueError(
+ f"The image is expected to have three or four color channels in '{mode}' mode but has "
+ f"{patch.shape[self.channel_dim]}. "
+ )
+ patch = patch[:3]
+
+ return patch
diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py
index 88f094c7322..b6e54a6c4e8 100644
--- a/monai/engines/__init__.py
+++ b/monai/engines/__init__.py
@@ -13,7 +13,6 @@
from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
from .trainer import GanTrainer, SupervisedTrainer, Trainer
from .utils import (
- GanKeys,
IterationEvents,
PrepareBatch,
PrepareBatchDefault,
@@ -24,4 +23,4 @@
engine_apply_transform,
get_devices_spec,
)
-from .workflow import BaseWorkflow, Workflow
+from .workflow import Workflow
diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py
index 7999bb9bd6e..2bf172418a2 100644
--- a/monai/engines/evaluator.py
+++ b/monai/engines/evaluator.py
@@ -22,8 +22,9 @@
from monai.inferers import Inferer, SimpleInferer
from monai.networks.utils import eval_mode, train_mode
from monai.transforms import Transform
-from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import
+from monai.utils import ForwardMode, deprecated, ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
+from monai.utils.enums import EngineStatsKeys as ESKeys
from monai.utils.module import look_up_option
if TYPE_CHECKING:
@@ -146,7 +147,28 @@ def run(self, global_epoch: int = 1) -> None:
self.state.iteration = 0
super().run()
- def get_validation_stats(self) -> dict[str, float]:
+ def get_stats(self, *vars):
+ """
+ Get the statistics information of the validation process.
+ Default to return the `rank`, `best_validation_epoch` and `best_validation_metric`.
+
+ Args:
+ vars: except for the default stats, other variables name in the `self.state` to return,
+ will use the variable name as the key and the state content as the value.
+ if the variable doesn't exist, default value is `None`.
+
+ """
+ stats = {
+ ESKeys.RANK: self.state.rank,
+ ESKeys.BEST_VALIDATION_EPOCH: self.state.best_metric_epoch,
+ ESKeys.BEST_VALIDATION_METRIC: self.state.best_metric,
+ }
+ for k in vars:
+ stats[k] = getattr(self.state, k, None)
+ return stats
+
+ @deprecated(since="0.9", msg_suffix="please use the `get_stats()` API instead.")
+ def get_validation_stats(self):
return {"best_validation_metric": self.state.best_metric, "best_validation_epoch": self.state.best_metric_epoch}
diff --git a/monai/engines/multi_gpu_supervised_trainer.py b/monai/engines/multi_gpu_supervised_trainer.py
index 0433617649c..a9171d0f506 100644
--- a/monai/engines/multi_gpu_supervised_trainer.py
+++ b/monai/engines/multi_gpu_supervised_trainer.py
@@ -31,8 +31,12 @@
from ignite.engine import Engine
from ignite.metrics import Metric
else:
- Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
- Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
+ Engine, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
+ )
+ Metric, _ = optional_import(
+ "ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="decorator"
+ )
__all__ = ["create_multigpu_supervised_trainer", "create_multigpu_supervised_evaluator"]
diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py
index 2b7a1acd2a4..12007b6294e 100644
--- a/monai/engines/trainer.py
+++ b/monai/engines/trainer.py
@@ -18,18 +18,13 @@
from torch.utils.data import DataLoader
from monai.config import IgniteInfo
-from monai.engines.utils import (
- GanKeys,
- IterationEvents,
- default_make_latent,
- default_metric_cmp_fn,
- default_prepare_batch,
-)
+from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
-from monai.utils import min_version, optional_import
+from monai.utils import GanKeys, deprecated, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
+from monai.utils.enums import EngineStatsKeys as ESKeys
if TYPE_CHECKING:
from ignite.engine import Engine, EventEnum
@@ -57,8 +52,31 @@ def run(self) -> None:
self.scaler = torch.cuda.amp.GradScaler() if self.amp else None
super().run()
- def get_train_stats(self) -> dict[str, float]:
- return {"total_epochs": self.state.max_epochs, "total_iterations": self.state.epoch_length}
+ def get_stats(self, *vars):
+ """
+ Get the statistics information of the training process.
+ Default to return the `rank`, `current_epoch`, `current_iteration`, `total_epochs`, `total_iterations`.
+
+ Args:
+ vars: except for the default stats, other variables name in the `self.state` to return,
+ will use the variable name as the key and the state content as the value.
+ if the variable doesn't exist, default value is `None`.
+
+ """
+ stats = {
+ ESKeys.RANK: self.state.rank,
+ ESKeys.CURRENT_EPOCH: self.state.epoch,
+ ESKeys.CURRENT_ITERATION: self.state.iteration,
+ ESKeys.TOTAL_EPOCHS: self.state.max_epochs,
+ ESKeys.TOTAL_ITERATIONS: self.state.epoch_length,
+ }
+ for k in vars:
+ stats[k] = getattr(self.state, k, None)
+ return stats
+
+ @deprecated(since="0.9", msg_suffix="please use the `get_stats()` API instead.")
+ def get_train_stats(self):
+ return self.get_stats()
class SupervisedTrainer(Trainer):
diff --git a/monai/engines/utils.py b/monai/engines/utils.py
index 8f3a57bedac..22a0e1de3d8 100644
--- a/monai/engines/utils.py
+++ b/monai/engines/utils.py
@@ -17,16 +17,17 @@
from monai.config import IgniteInfo
from monai.transforms import apply_transform
from monai.utils import ensure_tuple, min_version, optional_import
-from monai.utils.enums import CommonKeys
+from monai.utils.enums import CommonKeys, GanKeys
if TYPE_CHECKING:
from ignite.engine import EventEnum
else:
- EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
+ EventEnum, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
+ )
__all__ = [
"IterationEvents",
- "GanKeys",
"get_devices_spec",
"default_prepare_batch",
"PrepareBatch",
@@ -59,19 +60,6 @@ class IterationEvents(EventEnum):
INNER_ITERATION_COMPLETED = "inner_iteration_completed"
-class GanKeys:
- """
- A set of common keys for generative adversarial networks.
-
- """
-
- REALS = "reals"
- FAKES = "fakes"
- LATENTS = "latents"
- GLOSS = "g_loss"
- DLOSS = "d_loss"
-
-
def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[torch.device]:
"""
Get a valid specification for one or more devices. If `devices` is None get devices for all CUDA devices available.
@@ -104,30 +92,54 @@ def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[t
def default_prepare_batch(
- batchdata: Dict[str, torch.Tensor],
+ batchdata: Union[Dict[str, torch.Tensor], torch.Tensor, Sequence[torch.Tensor]],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
) -> Union[Tuple[torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
"""
Default function to prepare the data for current iteration.
- Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
- https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
- `kwargs` supports other args for `Tensor.to()` API.
+
+ The input `batchdata` is either a single tensor, a pair of tensors, or a dictionary of data. In the first case the
+ return value is the tensor and None, in the second case the return value is the two tensors, and in the dictionary
+ case the return value depends on what keys are present. if `CommonKeys.IMAGE` and `CommonKeys.LABEL` are present
+ then the tensors they key to are returned, if only `CommonKeys.IMAGE` is present that tensor and None is returned.
+ If `CommonKeys.REALS` is present this is returned with None. All returned tensors are moved to the given device
+ using the given non-blocking argument before being returned.
+
+ This function implemenets the expected API for a `prepare_batch` callable in Ignite:
+ https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html
+
+ Args:
+ batchdata: input batch data which is either a single tensor, a pair, or a dictionary
+ device: device to move every returned tensor to
+ non_blocking: equivalent argument for `Tensor.to`
+ kwargs: further arguments for `Tensor.to`
Returns:
image, label(optional).
-
"""
if not isinstance(batchdata, dict):
- raise AssertionError("default prepare_batch expects dictionary input data.")
+ if isinstance(batchdata, torch.Tensor):
+ return batchdata.to(device=device, non_blocking=non_blocking, **kwargs), None
+ elif len(batchdata) == 2:
+ image, label = batchdata
+ return (
+ image.to(device=device, non_blocking=non_blocking, **kwargs),
+ label.to(device=device, non_blocking=non_blocking, **kwargs),
+ )
+
+ raise AssertionError("Default prepare_batch expects a single tensor, a tensor pair, or dictionary input data.")
+
if isinstance(batchdata.get(CommonKeys.LABEL), torch.Tensor):
return (
batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs),
batchdata[CommonKeys.LABEL].to(device=device, non_blocking=non_blocking, **kwargs),
)
+
if GanKeys.REALS in batchdata:
return batchdata[GanKeys.REALS].to(device=device, non_blocking=non_blocking, **kwargs)
+
return batchdata[CommonKeys.IMAGE].to(device=device, non_blocking=non_blocking, **kwargs), None
@@ -138,7 +150,6 @@ class PrepareBatch(ABC):
Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
`kwargs` supports other args for `Tensor.to()` API.
-
"""
@abstractmethod
@@ -154,13 +165,12 @@ def __call__(
class PrepareBatchDefault(PrepareBatch):
"""
- Default prepare batch method to return `image` and `label` only,
- it's to be consistent with `default_prepare_batch` API.
+ This wraps `default_prepare_batch` to return `image` and `label` only, so is consistent with its API.
"""
def __call__(
self,
- batchdata: Dict[str, torch.Tensor],
+ batchdata: Union[Dict[str, torch.Tensor], torch.Tensor, Sequence[torch.Tensor]],
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
**kwargs,
@@ -176,16 +186,15 @@ def __call__(
class PrepareBatchExtraInput(PrepareBatch):
"""
- Customized prepare_batch for trainer or evaluator that support extra input data for network.
- Extra items are specified by the `extra_keys` parameter.
+ Customized prepare batch callable for trainers or evaluators which support extra input data for the network.
+ Extra items are specified by the `extra_keys` parameter and are extracted from the input dictionary (ie. the batch).
+ This uses `default_prepare_batch` but requires dictionary inputs.
Args:
- extra_keys: if a string or list provided, every item is the key of extra data in current batch,
- and will pass the extra data to the `network(*args)` in order.
- If a dictionary is provided, every `{k, v}` pair is the key of extra data in current batch,
- `k` is the param name in network, `v` is the key of extra data in current batch,
- and will pass the `{k1: batch[v1], k2: batch[v2], ...}` as kwargs to the network.
-
+ extra_keys: If a string or sequence of strings is provided, values from the input dictionary are extracted from
+ those keys and passed to the nework as extra positional arguments. If a dictionary is provided, every pair
+ `(k, v)` in that dictionary will become a new keyword argument assigning to `k` the value in the input
+ dictionary keyed to `v`.
"""
def __init__(self, extra_keys: Union[str, Sequence[str], Dict[str, str]]) -> None:
@@ -202,7 +211,6 @@ def __call__(
Args `batchdata`, `device`, `non_blocking` refer to the ignite API:
https://pytorch.org/ignite/v0.4.8/generated/ignite.engine.create_supervised_trainer.html.
`kwargs` supports other args for `Tensor.to()` API.
-
"""
image, label = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
args_ = list()
@@ -210,9 +218,11 @@ def __call__(
def _get_data(key: str):
data = batchdata[key]
- return (
- data.to(device=device, non_blocking=non_blocking, **kwargs) if isinstance(data, torch.Tensor) else data
- )
+
+ if isinstance(data, torch.Tensor):
+ data = data.to(device=device, non_blocking=non_blocking, **kwargs)
+
+ return data
if isinstance(self.extra_keys, (str, list, tuple)):
for k in ensure_tuple(self.extra_keys):
diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py
index 8349ff82abb..28f2430a878 100644
--- a/monai/engines/workflow.py
+++ b/monai/engines/workflow.py
@@ -10,7 +10,6 @@
# limitations under the License.
import warnings
-from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union
import torch
@@ -25,7 +24,7 @@
from .utils import engine_apply_transform
-IgniteEngine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+IgniteEngine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="")
State, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "State")
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
@@ -33,21 +32,15 @@
from ignite.engine import Engine, EventEnum
from ignite.metrics import Metric
else:
- Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
- Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
- EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
-
-
-class BaseWorkflow(ABC):
- """
- Base class for any MONAI style workflow.
- `run()` is designed to execute the train, evaluation or inference logic.
-
- """
-
- @abstractmethod
- def run(self, *args, **kwargs):
- raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
+ Engine, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
+ )
+ Metric, _ = optional_import(
+ "ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="decorator"
+ )
+ EventEnum, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="decorator"
+ )
class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import
@@ -133,7 +126,7 @@ def __init__(
else:
super().__init__(self._iteration)
if not isinstance(device, torch.device):
- raise TypeError(f"device must be a torch.device but is {type(device).__name__}.")
+ raise TypeError(f"Device must be a torch.device but is {type(device).__name__}.")
if isinstance(data_loader, DataLoader):
sampler = data_loader.__dict__["sampler"]
@@ -147,7 +140,7 @@ def set_sampler_epoch(engine: Engine):
epoch_length = len(data_loader)
else:
if epoch_length is None:
- raise ValueError("if data_loader is not PyTorch DataLoader, must specify the epoch_length.")
+ raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.")
# set all sharable data for the workflow based on Ignite engine.state
self.state = State(
@@ -180,7 +173,7 @@ def set_sampler_epoch(engine: Engine):
event_names = [IterationEvents] # type: ignore
else:
if not isinstance(event_names, list):
- raise ValueError("event_names must be a list or string or EventEnum.")
+ raise ValueError("`event_names` must be a list or string or EventEnum.")
event_names += [IterationEvents] # type: ignore
for name in event_names:
if isinstance(name, str):
@@ -188,7 +181,7 @@ def set_sampler_epoch(engine: Engine):
elif issubclass(name, EventEnum): # type: ignore
self.register_events(*name, event_to_attr=event_to_attr)
else:
- raise ValueError("event_names must be a list or string or EventEnum.")
+ raise ValueError("`event_names` must be a list or string or EventEnum.")
if decollate:
self._register_decollate()
@@ -239,12 +232,12 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None):
"""
if not isinstance(k_metric, dict):
- raise TypeError(f"key_metric must be None or a dict but is {type(k_metric).__name__}.")
+ raise TypeError(f"`key_metric` must be None or a dict but is {type(k_metric).__name__}.")
self.state.key_metric_name = list(k_metric.keys())[0]
metrics = dict(k_metric)
if add_metrics is not None and len(add_metrics) > 0:
if not isinstance(add_metrics, dict):
- raise TypeError(f"additional metrics must be None or a dict but is {type(add_metrics).__name__}.")
+ raise TypeError(f"Additional metrics must be None or a dict but is {type(add_metrics).__name__}.")
metrics.update(add_metrics)
for name, metric in metrics.items():
metric.attach(self, name)
@@ -256,12 +249,14 @@ def _compare_metrics(engine: Workflow) -> None:
current_val_metric = engine.state.metrics[key_metric_name]
if not is_scalar(current_val_metric):
warnings.warn(
- "key metric is not a scalar value, skip the metric comparison with the current best metric."
- "please set other metrics as the key metric, or change the `reduction` mode to 'mean'."
+ "Key metric is not a scalar value, skip the metric comparison with the current best metric."
+ "Please set other metrics as the key metric, or change the `reduction` mode to 'mean'."
)
return
- if self.metric_cmp_fn(current_val_metric, engine.state.best_metric):
+ if engine.state.best_metric_epoch == -1 or self.metric_cmp_fn(
+ current_val_metric, engine.state.best_metric
+ ):
self.logger.info(f"Got new best metric of {key_metric_name}: {current_val_metric}")
engine.state.best_metric = current_val_metric
engine.state.best_metric_epoch = engine.state.epoch
@@ -278,12 +273,11 @@ def _register_handlers(self, handlers: Sequence):
def run(self) -> None:
"""
Execute training, validation or evaluation based on Ignite Engine.
-
"""
if self.state.epoch_length == 0:
warnings.warn(
"`dataloader` is empty or the specified `epoch_length` is 0, skip the `run`."
- " if running distributed training, the program may hang in `all-gather`, `all-reduce`, etc."
+ " If running distributed training, the program may hang in `all-gather`, `all-reduce`, etc."
" because not all the ranks run the same computation logic."
)
return
@@ -303,3 +297,14 @@ def _iteration(self, engine, batchdata: Dict[str, torch.Tensor]):
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
+
+ def get_stats(self, *vars):
+ """
+ Get the statistics information of the workflow process.
+
+ Args:
+ vars: variables name in the `self.state`, will use the variable name as the key
+ and the state content as the value. if the variable doesn't exist, default value is `None`.
+
+ """
+ return {k: getattr(self.state, k, None) for k in vars}
diff --git a/monai/fl/__init__.py b/monai/fl/__init__.py
new file mode 100644
index 00000000000..1e97f894078
--- /dev/null
+++ b/monai/fl/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/fl/client/__init__.py b/monai/fl/client/__init__.py
new file mode 100644
index 00000000000..7acb82c635e
--- /dev/null
+++ b/monai/fl/client/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .client_algo import BaseClient, ClientAlgo, ClientAlgoStats
+from .monai_algo import MonaiAlgo, MonaiAlgoStats
diff --git a/monai/fl/client/client_algo.py b/monai/fl/client/client_algo.py
new file mode 100644
index 00000000000..9c54f2891b6
--- /dev/null
+++ b/monai/fl/client/client_algo.py
@@ -0,0 +1,152 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from monai.fl.utils.exchange_object import ExchangeObject
+
+
+class BaseClient:
+ """
+ Provide an abstract base class to allow the client to return summary statistics of the data.
+
+ To define a new stats script, subclass this class and implement the
+ following abstract methods::
+
+ - self.get_data_stats()
+
+ initialize(), abort(), and finalize() -- inherited from `ClientAlgoStats`; can be optionally be implemented
+ to help with lifecycle management of the class object.
+ """
+
+ def initialize(self, extra: Optional[dict] = None):
+ """
+ Call to initialize the ClientAlgo class.
+
+ Args:
+ extra: optional extra information, e.g. dict of `ExtraItems.CLIENT_NAME` and/or `ExtraItems.APP_ROOT`.
+ """
+ pass
+
+ def finalize(self, extra: Optional[dict] = None):
+ """
+ Call to finalize the ClientAlgo class.
+
+ Args:
+ extra: Dict with additional information that can be provided by the FL system.
+ """
+ pass
+
+ def abort(self, extra: Optional[dict] = None):
+ """
+ Call to abort the ClientAlgo training or evaluation.
+
+ Args:
+ extra: Dict with additional information that can be provided by the FL system.
+ """
+
+ pass
+
+
+class ClientAlgoStats(BaseClient):
+ def get_data_stats(self, extra: Optional[dict] = None) -> ExchangeObject:
+ """
+ Get summary statistics about the local data.
+
+ Args:
+ extra: Dict with additional information that can be provided by the FL system.
+ For example, requested statistics.
+
+ Returns:
+
+ ExchangeObject: summary statistics.
+
+ Extra dict example::
+
+ requested_stats = {
+ FlStatistics.STATISTICS: metrics,
+ FlStatistics.NUM_OF_BINS: num_of_bins,
+ FlStatistics.BIN_RANGES: bin_ranges
+ }
+
+ Returned ExchangeObject example::
+
+ ExchangeObject(
+ statistics = {...}
+ )
+
+ """
+ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
+
+
+class ClientAlgo(ClientAlgoStats):
+ """
+ Provide an abstract base class for defining algo to run on any platform.
+ To define a new algo script, subclass this class and implement the
+ following abstract methods:
+
+ - self.train()
+ - self.get_weights()
+ - self.evaluate()
+ - self.get_data_stats() (optional, inherited from `ClientAlgoStats`)
+
+ initialize(), abort(), and finalize() - inherited from `ClientAlgoStats` - can be optionally be implemented
+ to help with lifecycle management of the class object.
+ """
+
+ def train(self, data: ExchangeObject, extra: Optional[dict] = None) -> None:
+ """
+ Train network and produce new network from train data.
+
+ Args:
+ data: ExchangeObject containing current network weights to base training on.
+ extra: Dict with additional information that can be provided by the FL system.
+
+ Returns:
+ None
+ """
+ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
+
+ def get_weights(self, extra: Optional[dict] = None) -> ExchangeObject:
+ """
+ Get current local weights or weight differences.
+
+ Args:
+ extra: Dict with additional information that can be provided by the FL system.
+
+ Returns:
+ ExchangeObject: current local weights or weight differences.
+
+ `ExchangeObject` example:
+
+ .. code-block:: python
+
+ ExchangeObject(
+ weights = self.trainer.network.state_dict(),
+ optim = None, # could be self.optimizer.state_dict()
+ weight_type = WeightType.WEIGHTS
+ )
+
+ """
+ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
+
+ def evaluate(self, data: ExchangeObject, extra: Optional[dict] = None) -> ExchangeObject:
+ """
+ Get evaluation metrics on test data.
+
+ Args:
+ data: ExchangeObject with network weights to use for evaluation.
+ extra: Dict with additional information that can be provided by the FL system.
+
+ Returns:
+ metrics: ExchangeObject with evaluation metrics.
+ """
+ raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py
new file mode 100644
index 00000000000..f18e7faa635
--- /dev/null
+++ b/monai/fl/client/monai_algo.py
@@ -0,0 +1,713 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+
+import monai
+from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
+from monai.auto3dseg import SegSummarizer
+from monai.bundle import ConfigParser
+from monai.bundle.config_item import ConfigComponent, ConfigItem
+from monai.fl.client import ClientAlgo, ClientAlgoStats
+from monai.fl.utils.constants import (
+ BundleKeys,
+ ExtraItems,
+ FiltersType,
+ FlPhase,
+ FlStatistics,
+ ModelType,
+ RequiredBundleKeys,
+ WeightType,
+)
+from monai.fl.utils.exchange_object import ExchangeObject
+from monai.networks.utils import copy_model_state, get_state_dict
+from monai.utils import min_version, require_pkg
+from monai.utils.enums import DataStatsKeys
+
+logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(asctime)s - %(message)s")
+
+
+def convert_global_weights(global_weights, local_var_dict):
+ """Helper function to convert global weights to local weights format"""
+ # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
+ model_keys = global_weights.keys()
+ n_converted = 0
+ for var_name in local_var_dict:
+ if var_name in model_keys:
+ weights = global_weights[var_name]
+ try:
+ # reshape global weights to compute difference later on
+ weights = torch.reshape(torch.as_tensor(weights), local_var_dict[var_name].shape)
+ # update the local dict
+ local_var_dict[var_name] = weights
+ n_converted += 1
+ except Exception as e:
+ raise ValueError(f"Convert weight from {var_name} failed.") from e
+ return local_var_dict, n_converted
+
+
+def compute_weight_diff(global_weights, local_var_dict):
+ if global_weights is None:
+ raise ValueError("Cannot compute weight differences if `global_weights` is None!")
+ if local_var_dict is None:
+ raise ValueError("Cannot compute weight differences if `local_var_dict` is None!")
+ # compute delta model, global model has the primary key set
+ weight_diff = {}
+ for name in global_weights:
+ if name not in local_var_dict:
+ continue
+ # returned weight diff will be on the cpu
+ weight_diff[name] = local_var_dict[name].cpu() - global_weights[name].cpu()
+ if torch.any(torch.isnan(weight_diff[name])):
+ raise ValueError(f"Weights for {name} became NaN...")
+ return weight_diff
+
+
+def check_bundle_config(parser):
+ for k in RequiredBundleKeys:
+ if parser.get(k, None) is None:
+ raise KeyError(f"Bundle config misses required key `{k}`")
+
+
+def disable_ckpt_loaders(parser):
+ if BundleKeys.VALIDATE_HANDLERS in parser:
+ for h in parser[BundleKeys.VALIDATE_HANDLERS]:
+ if ConfigComponent.is_instantiable(h):
+ if "CheckpointLoader" in h["_target_"]:
+ h["_disabled_"] = True
+
+
+class MonaiAlgoStats(ClientAlgoStats):
+ """
+ Implementation of ``ClientAlgo`` to allow federated learning with MONAI bundle configurations.
+
+ Args:
+ bundle_root: path of bundle.
+ config_train_filename: bundle training config path relative to bundle_root. Can be a list of files;
+ defaults to "configs/train.json".
+ config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.
+ histogram_only: whether to only compute histograms. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ bundle_root: str,
+ config_train_filename: Optional[Union[str, list]] = "configs/train.json",
+ config_filters_filename: Optional[Union[str, list]] = None,
+ train_data_key: Optional[str] = BundleKeys.TRAIN_DATA,
+ eval_data_key: Optional[str] = BundleKeys.VALID_DATA,
+ data_stats_transform_list: Optional[list] = None,
+ histogram_only: bool = False,
+ ):
+ self.logger = logging.getLogger(self.__class__.__name__)
+ self.bundle_root = bundle_root
+ self.config_train_filename = config_train_filename
+ self.config_filters_filename = config_filters_filename
+ self.train_data_key = train_data_key
+ self.eval_data_key = eval_data_key
+ self.data_stats_transform_list = data_stats_transform_list
+ self.histogram_only = histogram_only
+
+ self.client_name = None
+ self.app_root = None
+ self.train_parser = None
+ self.filter_parser = None
+ self.post_statistics_filters = None
+ self.phase = FlPhase.IDLE
+ self.dataset_root = None
+
+ def initialize(self, extra=None):
+ """
+ Initialize routine to parse configuration files and extract main components such as trainer, evaluator, and filters.
+
+ Args:
+ extra: Dict with additional information that should be provided by FL system,
+ i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
+
+ """
+ if extra is None:
+ extra = {}
+ self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
+ self.logger.info(f"Initializing {self.client_name} ...")
+
+ # FL platform needs to provide filepath to configuration files
+ self.app_root = extra.get(ExtraItems.APP_ROOT, "")
+
+ # Read bundle config files
+ self.bundle_root = os.path.join(self.app_root, self.bundle_root)
+
+ config_train_files = self._add_config_files(self.config_train_filename)
+ config_filter_files = self._add_config_files(self.config_filters_filename)
+
+ # Parse
+ self.train_parser = ConfigParser()
+ self.filter_parser = ConfigParser()
+ if len(config_train_files) > 0:
+ self.train_parser.read_config(config_train_files)
+ check_bundle_config(self.train_parser)
+ if len(config_filter_files) > 0:
+ self.filter_parser.read_config(config_filter_files)
+
+ # override some config items
+ self.train_parser[RequiredBundleKeys.BUNDLE_ROOT] = self.bundle_root
+
+ # Get data location
+ self.dataset_root = self.train_parser.get_parsed_content(
+ BundleKeys.DATASET_DIR, default=ConfigItem(None, BundleKeys.DATASET_DIR)
+ )
+
+ # Get filters
+ self.post_statistics_filters = self.filter_parser.get_parsed_content(
+ FiltersType.POST_STATISTICS_FILTERS, default=ConfigItem(None, FiltersType.POST_STATISTICS_FILTERS)
+ )
+
+ self.logger.info(f"Initialized {self.client_name}.")
+
+ def get_data_stats(self, extra: Optional[dict] = None) -> ExchangeObject:
+ """
+ Returns summary statistics about the local data.
+
+ Args:
+ extra: Dict with additional information that can be provided by the FL system.
+
+ Returns:
+ stats: ExchangeObject with summary statistics.
+
+ """
+
+ if self.dataset_root:
+ self.phase = FlPhase.GET_DATA_STATS
+ self.logger.info(f"Computing statistics on {self.dataset_root}")
+
+ if FlStatistics.HIST_BINS not in extra:
+ raise ValueError("FlStatistics.NUM_OF_BINS not specified in `extra`")
+ else:
+ hist_bins = extra[FlStatistics.HIST_BINS]
+ if FlStatistics.HIST_RANGE not in extra:
+ raise ValueError("FlStatistics.HIST_RANGE not specified in `extra`")
+ else:
+ hist_range = extra[FlStatistics.HIST_RANGE]
+
+ stats_dict = {}
+
+ # train data stats
+ train_summary_stats, train_case_stats = self._get_data_key_stats(
+ parser=self.train_parser,
+ data_key=self.train_data_key,
+ hist_bins=hist_bins,
+ hist_range=hist_range,
+ output_path=os.path.join(self.app_root, "train_data_stats.yaml"),
+ )
+ if train_case_stats:
+ # Only return summary statistics to FL server
+ stats_dict.update({self.train_data_key: train_summary_stats})
+
+ # eval data stats
+ eval_summary_stats, eval_case_stats = self._get_data_key_stats(
+ parser=self.train_parser,
+ data_key=self.eval_data_key,
+ hist_bins=hist_bins,
+ hist_range=hist_range,
+ output_path=os.path.join(self.app_root, "eval_data_stats.yaml"),
+ )
+ if eval_summary_stats:
+ # Only return summary statistics to FL server
+ stats_dict.update({self.eval_data_key: eval_summary_stats})
+
+ # total stats
+ if train_case_stats and eval_case_stats:
+ # Compute total summary
+ total_summary_stats = self._compute_total_stats(
+ [train_case_stats, eval_case_stats], hist_bins, hist_range
+ )
+ stats_dict.update({FlStatistics.TOTAL_DATA: total_summary_stats})
+
+ # optional filter of data stats
+ stats = ExchangeObject(statistics=stats_dict)
+ if self.post_statistics_filters is not None:
+ for _filter in self.post_statistics_filters:
+ stats = _filter(stats, extra)
+
+ return stats
+ else:
+ raise ValueError("data_root not set!")
+
+ def _get_data_key_stats(self, parser, data_key, hist_bins, hist_range, output_path=None):
+ if data_key not in parser:
+ self.logger.warning(f"Data key {data_key} not available in bundle configs.")
+ return None, None
+ data = parser.get_parsed_content(data_key)
+
+ datalist = {data_key: data}
+
+ analyzer = DataAnalyzer(
+ datalist=datalist,
+ dataroot=self.dataset_root,
+ hist_bins=hist_bins,
+ hist_range=hist_range,
+ output_path=output_path,
+ histogram_only=self.histogram_only,
+ )
+
+ self.logger.info(f"{self.client_name} compute data statistics on {data_key}...")
+ all_stats = analyzer.get_all_case_stats(transform_list=self.data_stats_transform_list, key=data_key)
+
+ case_stats = all_stats[DataStatsKeys.BY_CASE]
+
+ summary_stats = {
+ FlStatistics.DATA_STATS: all_stats[DataStatsKeys.SUMMARY],
+ FlStatistics.DATA_COUNT: len(data),
+ FlStatistics.FAIL_COUNT: len(data) - len(case_stats),
+ # TODO: add shapes, voxels sizes, etc.
+ }
+
+ return summary_stats, case_stats
+
+ @staticmethod
+ def _compute_total_stats(case_stats_lists, hist_bins, hist_range):
+ # Compute total summary
+ total_case_stats = []
+ for case_stats_list in case_stats_lists:
+ total_case_stats += case_stats_list
+
+ summarizer = SegSummarizer(
+ "image", "label", average=True, do_ccp=True, hist_bins=hist_bins, hist_range=hist_range
+ )
+ total_summary_stats = summarizer.summarize(total_case_stats)
+
+ summary_stats = {
+ FlStatistics.DATA_STATS: total_summary_stats,
+ FlStatistics.DATA_COUNT: len(total_case_stats),
+ FlStatistics.FAIL_COUNT: 0,
+ }
+
+ return summary_stats
+
+ def _add_config_files(self, config_files):
+ files = []
+ if config_files:
+ if isinstance(config_files, str):
+ files.append(os.path.join(self.bundle_root, config_files))
+ elif isinstance(config_files, list):
+ for file in config_files:
+ if isinstance(file, str):
+ files.append(os.path.join(self.bundle_root, file))
+ else:
+ raise ValueError(f"Expected config file to be of type str but got {type(file)}: {file}")
+ else:
+ raise ValueError(
+ f"Expected config files to be of type str or list but got {type(config_files)}: {config_files}"
+ )
+ return files
+
+
+@require_pkg(pkg_name="ignite", version="0.4.10", version_checker=min_version)
+class MonaiAlgo(ClientAlgo, MonaiAlgoStats):
+ """
+ Implementation of ``ClientAlgo`` to allow federated learning with MONAI bundle configurations.
+
+ Args:
+ bundle_root: path of bundle.
+ local_epochs: number of local epochs to execute during each round of local training; defaults to 1.
+ send_weight_diff: whether to send weight differences rather than full weights; defaults to `True`.
+ config_train_filename: bundle training config path relative to bundle_root. Can be a list of files;
+ defaults to "configs/train.json".
+ config_evaluate_filename: bundle evaluation config path relative to bundle_root. Can be a list of files.
+ If "default", config_evaluate_filename = ["configs/train.json", "configs/evaluate.json"] will be used;
+ config_filters_filename: filter configuration file. Can be a list of files; defaults to `None`.
+ disable_ckpt_loading: do not use any CheckpointLoader if defined in train/evaluate configs; defaults to `True`.
+ best_model_filepath: location of best model checkpoint; defaults "models/model.pt" relative to `bundle_root`.
+ final_model_filepath: location of final model checkpoint; defaults "models/model_final.pt" relative to `bundle_root`.
+ save_dict_key: If a model checkpoint contains several state dicts,
+ the one defined by `save_dict_key` will be returned by `get_weights`; defaults to "model".
+ If all state dicts should be returned, set `save_dict_key` to None.
+ seed: set random seed for modules to enable or disable deterministic training; defaults to `None`,
+ i.e., non-deterministic training.
+ benchmark: set benchmark to `False` for full deterministic behavior in cuDNN components.
+ Note, full determinism in federated learning depends also on deterministic behavior of other FL components,
+ e.g., the aggregator, which is not controlled by this class.
+ multi_gpu: whether to run MonaiAlgo in a multi-GPU setting; defaults to `False`.
+ backend: backend to use for torch.distributed; defaults to "nccl".
+ init_method: init_method for torch.distributed; defaults to "env://".
+ """
+
+ def __init__(
+ self,
+ bundle_root: str,
+ local_epochs: int = 1,
+ send_weight_diff: bool = True,
+ config_train_filename: Optional[Union[str, list]] = "configs/train.json",
+ config_evaluate_filename: Optional[Union[str, list]] = "default",
+ config_filters_filename: Optional[Union[str, list]] = None,
+ disable_ckpt_loading: bool = True,
+ best_model_filepath: Optional[str] = "models/model.pt",
+ final_model_filepath: Optional[str] = "models/model_final.pt",
+ save_dict_key: Optional[str] = "model",
+ seed: Optional[int] = None,
+ benchmark: bool = True,
+ multi_gpu: bool = False,
+ backend: str = "nccl",
+ init_method: str = "env://",
+ train_data_key: Optional[str] = BundleKeys.TRAIN_DATA,
+ eval_data_key: Optional[str] = BundleKeys.VALID_DATA,
+ data_stats_transform_list: Optional[list] = None,
+ ):
+ self.logger = logging.getLogger(self.__class__.__name__)
+ if config_evaluate_filename == "default":
+ # by default, evaluator needs both training and evaluate to be instantiated.
+ config_evaluate_filename = ["configs/train.json", "configs/evaluate.json"]
+ self.bundle_root = bundle_root
+ self.local_epochs = local_epochs
+ self.send_weight_diff = send_weight_diff
+ self.config_train_filename = config_train_filename
+ self.config_evaluate_filename = config_evaluate_filename
+ self.config_filters_filename = config_filters_filename
+ self.disable_ckpt_loading = disable_ckpt_loading
+ self.model_filepaths = {ModelType.BEST_MODEL: best_model_filepath, ModelType.FINAL_MODEL: final_model_filepath}
+ self.save_dict_key = save_dict_key
+ self.seed = seed
+ self.benchmark = benchmark
+ self.multi_gpu = multi_gpu
+ self.backend = backend
+ self.init_method = init_method
+ self.train_data_key = train_data_key
+ self.eval_data_key = eval_data_key
+ self.data_stats_transform_list = data_stats_transform_list
+
+ self.app_root = None
+ self.train_parser = None
+ self.eval_parser = None
+ self.filter_parser = None
+ self.trainer = None
+ self.evaluator = None
+ self.pre_filters = None
+ self.post_weight_filters = None
+ self.post_evaluate_filters = None
+ self.iter_of_start_time = 0
+ self.global_weights = None
+ self.rank = 0
+
+ self.phase = FlPhase.IDLE
+ self.client_name = None
+ self.dataset_root = None
+
+ def initialize(self, extra=None):
+ """
+ Initialize routine to parse configuration files and extract main components such as trainer, evaluator, and filters.
+
+ Args:
+ extra: Dict with additional information that should be provided by FL system,
+ i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
+
+ """
+ if extra is None:
+ extra = {}
+ self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
+ self.logger.info(f"Initializing {self.client_name} ...")
+
+ if self.multi_gpu:
+ dist.init_process_group(backend=self.backend, init_method=self.init_method)
+ self._set_cuda_device()
+ self.logger.info(
+ f"Using multi-gpu training on rank {self.rank} (available devices: {torch.cuda.device_count()})"
+ )
+ if self.rank > 0:
+ self.logger.setLevel(logging.WARNING)
+
+ if self.seed:
+ monai.utils.set_determinism(seed=self.seed)
+ torch.backends.cudnn.benchmark = self.benchmark
+
+ # FL platform needs to provide filepath to configuration files
+ self.app_root = extra.get(ExtraItems.APP_ROOT, "")
+
+ # Read bundle config files
+ self.bundle_root = os.path.join(self.app_root, self.bundle_root)
+
+ config_train_files = self._add_config_files(self.config_train_filename)
+ config_eval_files = self._add_config_files(self.config_evaluate_filename)
+ config_filter_files = self._add_config_files(self.config_filters_filename)
+
+ # Parse
+ self.train_parser = ConfigParser()
+ self.eval_parser = ConfigParser()
+ self.filter_parser = ConfigParser()
+ if len(config_train_files) > 0:
+ self.train_parser.read_config(config_train_files)
+ check_bundle_config(self.train_parser)
+ if len(config_eval_files) > 0:
+ self.eval_parser.read_config(config_eval_files)
+ check_bundle_config(self.eval_parser)
+ if len(config_filter_files) > 0:
+ self.filter_parser.read_config(config_filter_files)
+
+ # override some config items
+ self.train_parser[RequiredBundleKeys.BUNDLE_ROOT] = self.bundle_root
+ self.eval_parser[RequiredBundleKeys.BUNDLE_ROOT] = self.bundle_root
+ # number of training epochs for each round
+ if BundleKeys.TRAIN_TRAINER_MAX_EPOCHS in self.train_parser:
+ self.train_parser[BundleKeys.TRAIN_TRAINER_MAX_EPOCHS] = self.local_epochs
+
+ # remove checkpoint loaders
+ if self.disable_ckpt_loading:
+ disable_ckpt_loaders(self.train_parser)
+ disable_ckpt_loaders(self.eval_parser)
+
+ # Get trainer, evaluator
+ self.trainer = self.train_parser.get_parsed_content(
+ BundleKeys.TRAINER, default=ConfigItem(None, BundleKeys.TRAINER)
+ )
+ self.evaluator = self.eval_parser.get_parsed_content(
+ BundleKeys.EVALUATOR, default=ConfigItem(None, BundleKeys.EVALUATOR)
+ )
+
+ # Get filters
+ self.pre_filters = self.filter_parser.get_parsed_content(
+ FiltersType.PRE_FILTERS, default=ConfigItem(None, FiltersType.PRE_FILTERS)
+ )
+ self.post_weight_filters = self.filter_parser.get_parsed_content(
+ FiltersType.POST_WEIGHT_FILTERS, default=ConfigItem(None, FiltersType.POST_WEIGHT_FILTERS)
+ )
+ self.post_evaluate_filters = self.filter_parser.get_parsed_content(
+ FiltersType.POST_EVALUATE_FILTERS, default=ConfigItem(None, FiltersType.POST_EVALUATE_FILTERS)
+ )
+ self.post_statistics_filters = self.filter_parser.get_parsed_content(
+ FiltersType.POST_STATISTICS_FILTERS, default=ConfigItem(None, FiltersType.POST_STATISTICS_FILTERS)
+ )
+
+ # Get data location
+ self.dataset_root = self.train_parser.get_parsed_content(
+ BundleKeys.DATASET_DIR, default=ConfigItem(None, BundleKeys.DATASET_DIR)
+ )
+
+ if self.multi_gpu:
+ if self.rank > 0 and self.trainer:
+ self.trainer.logger.setLevel(logging.WARNING)
+ if self.rank > 0 and self.evaluator:
+ self.evaluator.logger.setLevel(logging.WARNING)
+ self.logger.info(f"Initialized {self.client_name}.")
+
+ def train(self, data: ExchangeObject, extra=None):
+ """
+ Train on client's local data.
+
+ Args:
+ data: `ExchangeObject` containing the current global model weights.
+ extra: Dict with additional information that can be provided by the FL system.
+
+ """
+ self._set_cuda_device()
+
+ if extra is None:
+ extra = {}
+ if not isinstance(data, ExchangeObject):
+ raise ValueError(f"expected data to be ExchangeObject but received {type(data)}")
+
+ if self.trainer is None:
+ raise ValueError("self.trainer should not be None.")
+ if self.pre_filters is not None:
+ for _filter in self.pre_filters:
+ data = _filter(data, extra)
+ self.phase = FlPhase.TRAIN
+ self.logger.info(f"Load {self.client_name} weights...")
+ local_var_dict = get_state_dict(self.trainer.network)
+ self.global_weights, n_converted = convert_global_weights(
+ global_weights=data.weights, local_var_dict=local_var_dict
+ )
+ self._check_converted(data.weights, local_var_dict, n_converted)
+
+ # set engine state max epochs.
+ self.trainer.state.max_epochs = self.trainer.state.epoch + self.local_epochs
+ # get current iteration when a round starts
+ self.iter_of_start_time = self.trainer.state.iteration
+
+ _, updated_keys, _ = copy_model_state(src=self.global_weights, dst=self.trainer.network)
+ if len(updated_keys) == 0:
+ self.logger.warning("No weights loaded!")
+ self.logger.info(f"Start {self.client_name} training...")
+ self.trainer.run()
+
+ def get_weights(self, extra=None):
+ """
+ Returns the current weights of the model.
+
+ Args:
+ extra: Dict with additional information that can be provided by the FL system.
+
+ Returns:
+ return_weights: `ExchangeObject` containing current weights (default)
+ or load requested model type from disk (`ModelType.BEST_MODEL` or `ModelType.FINAL_MODEL`).
+
+ """
+ self._set_cuda_device()
+
+ if extra is None:
+ extra = {}
+
+ # by default return current weights, return best if requested via model type.
+ self.phase = FlPhase.GET_WEIGHTS
+
+ if ExtraItems.MODEL_TYPE in extra:
+ model_type = extra.get(ExtraItems.MODEL_TYPE)
+ if not isinstance(model_type, ModelType):
+ raise ValueError(
+ f"Expected requested model type to be of type `ModelType` but received {type(model_type)}"
+ )
+ if model_type in self.model_filepaths:
+ model_path = os.path.join(self.bundle_root, self.model_filepaths[model_type])
+ if not os.path.isfile(model_path):
+ raise ValueError(f"No best model checkpoint exists at {model_path}")
+ weights = torch.load(model_path, map_location="cpu")
+ # if weights contain several state dicts, use the one defined by `save_dict_key`
+ if isinstance(weights, dict) and self.save_dict_key in weights:
+ weights = weights.get(self.save_dict_key)
+ weigh_type = WeightType.WEIGHTS
+ stats = dict()
+ self.logger.info(f"Returning {model_type} checkpoint weights from {model_path}.")
+ else:
+ raise ValueError(
+ f"Requested model type {model_type} not specified in `model_filepaths`: {self.model_filepaths}"
+ )
+ else:
+ if self.trainer:
+ weights = get_state_dict(self.trainer.network)
+ # returned weights will be on the cpu
+ for k in weights.keys():
+ weights[k] = weights[k].cpu()
+ weigh_type = WeightType.WEIGHTS
+ stats = self.trainer.get_stats()
+ # calculate current iteration and epoch data after training.
+ stats[FlStatistics.NUM_EXECUTED_ITERATIONS] = self.trainer.state.iteration - self.iter_of_start_time
+ # compute weight differences
+ if self.send_weight_diff:
+ weights = compute_weight_diff(global_weights=self.global_weights, local_var_dict=weights)
+ weigh_type = WeightType.WEIGHT_DIFF
+ self.logger.info("Returning current weight differences.")
+ else:
+ self.logger.info("Returning current weights.")
+ else:
+ weights = None
+ weigh_type = None
+ stats = dict()
+
+ if not isinstance(stats, dict):
+ raise ValueError(f"stats is not a dict, {stats}")
+ return_weights = ExchangeObject(
+ weights=weights,
+ optim=None, # could be self.optimizer.state_dict()
+ weight_type=weigh_type,
+ statistics=stats,
+ )
+
+ # filter weights if needed (use to apply differential privacy, encryption, compression, etc.)
+ if self.post_weight_filters is not None:
+ for _filter in self.post_weight_filters:
+ return_weights = _filter(return_weights, extra)
+
+ return return_weights
+
+ def evaluate(self, data: ExchangeObject, extra=None):
+ """
+ Evaluate on client's local data.
+
+ Args:
+ data: `ExchangeObject` containing the current global model weights.
+ extra: Dict with additional information that can be provided by the FL system.
+
+ Returns:
+ return_metrics: `ExchangeObject` containing evaluation metrics.
+
+ """
+ self._set_cuda_device()
+
+ if extra is None:
+ extra = {}
+ if not isinstance(data, ExchangeObject):
+ raise ValueError(f"expected data to be ExchangeObject but received {type(data)}")
+
+ if self.evaluator is None:
+ raise ValueError("self.evaluator should not be None.")
+ if self.pre_filters is not None:
+ for _filter in self.pre_filters:
+ data = _filter(data, extra)
+
+ self.phase = FlPhase.EVALUATE
+ self.logger.info(f"Load {self.client_name} weights...")
+ local_var_dict = get_state_dict(self.evaluator.network)
+ global_weights, n_converted = convert_global_weights(global_weights=data.weights, local_var_dict=local_var_dict)
+ self._check_converted(data.weights, local_var_dict, n_converted)
+
+ _, updated_keys, _ = copy_model_state(src=global_weights, dst=self.evaluator.network)
+ if len(updated_keys) == 0:
+ self.logger.warning("No weights loaded!")
+ self.logger.info(f"Start {self.client_name} evaluating...")
+ if isinstance(self.trainer, monai.engines.Trainer):
+ self.evaluator.run(self.trainer.state.epoch + 1)
+ else:
+ self.evaluator.run()
+ return_metrics = ExchangeObject(metrics=self.evaluator.state.metrics)
+
+ if self.post_evaluate_filters is not None:
+ for _filter in self.post_evaluate_filters:
+ return_metrics = _filter(return_metrics, extra)
+ return return_metrics
+
+ def abort(self, extra=None):
+ """
+ Abort the training or evaluation.
+ Args:
+ extra: Dict with additional information that can be provided by the FL system.
+ """
+ self.logger.info(f"Aborting {self.client_name} during {self.phase} phase.")
+ if isinstance(self.trainer, monai.engines.Trainer):
+ self.logger.info(f"Aborting {self.client_name} trainer...")
+ self.trainer.interrupt()
+ if isinstance(self.evaluator, monai.engines.Trainer):
+ self.logger.info(f"Aborting {self.client_name} evaluator...")
+ self.evaluator.interrupt()
+
+ def finalize(self, extra=None):
+ """
+ Finalize the training or evaluation.
+ Args:
+ extra: Dict with additional information that can be provided by the FL system.
+ """
+ self.logger.info(f"Terminating {self.client_name} during {self.phase} phase.")
+ if isinstance(self.trainer, monai.engines.Trainer):
+ self.logger.info(f"Terminating {self.client_name} trainer...")
+ self.trainer.terminate()
+ if isinstance(self.evaluator, monai.engines.Trainer):
+ self.logger.info(f"Terminating {self.client_name} evaluator...")
+ self.evaluator.terminate()
+
+ if self.multi_gpu:
+ dist.destroy_process_group()
+
+ def _check_converted(self, global_weights, local_var_dict, n_converted):
+ if n_converted == 0:
+ self.logger.warning(
+ f"No global weights converted! Received weight dict keys are {list(global_weights.keys())}"
+ )
+ else:
+ self.logger.info(
+ f"Converted {n_converted} global variables to match {len(local_var_dict)} local variables."
+ )
+
+ def _set_cuda_device(self):
+ if self.multi_gpu:
+ self.rank = int(os.environ["LOCAL_RANK"])
+ torch.cuda.set_device(self.rank)
diff --git a/monai/fl/utils/__init__.py b/monai/fl/utils/__init__.py
new file mode 100644
index 00000000000..1e97f894078
--- /dev/null
+++ b/monai/fl/utils/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py
new file mode 100644
index 00000000000..cd24e6093de
--- /dev/null
+++ b/monai/fl/utils/constants.py
@@ -0,0 +1,70 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from monai.utils.enums import StrEnum
+
+
+class WeightType(StrEnum):
+ WEIGHTS = "fl_weights_full"
+ WEIGHT_DIFF = "fl_weight_diff"
+
+
+class ModelType(StrEnum):
+ BEST_MODEL = "fl_best_model"
+ FINAL_MODEL = "fl_final_model"
+
+
+class ExtraItems(StrEnum):
+ ABORT = "fl_abort"
+ MODEL_TYPE = "fl_model_type"
+ CLIENT_NAME = "fl_client_name"
+ APP_ROOT = "fl_app_root"
+
+
+class FlPhase(StrEnum):
+ IDLE = "fl_idle"
+ TRAIN = "fl_train"
+ EVALUATE = "fl_evaluate"
+ GET_WEIGHTS = "fl_get_weights"
+ GET_DATA_STATS = "fl_get_data_stats"
+
+
+class FlStatistics(StrEnum):
+ NUM_EXECUTED_ITERATIONS = "num_executed_iterations"
+ STATISTICS = "statistics"
+ HIST_BINS = "hist_bins"
+ HIST_RANGE = "hist_range"
+ DATA_STATS = "data_stats"
+ DATA_COUNT = "data_count"
+ FAIL_COUNT = "fail_count"
+ TOTAL_DATA = "total_data"
+ FEATURE_NAMES = "feature_names"
+
+
+class RequiredBundleKeys(StrEnum):
+ BUNDLE_ROOT = "bundle_root"
+
+
+class BundleKeys(StrEnum):
+ TRAINER = "train#trainer"
+ EVALUATOR = "validate#evaluator"
+ TRAIN_TRAINER_MAX_EPOCHS = "train#trainer#max_epochs"
+ VALIDATE_HANDLERS = "validate#handlers"
+ DATASET_DIR = "dataset_dir"
+ TRAIN_DATA = "train#dataset#data"
+ VALID_DATA = "validate#dataset#data"
+
+
+class FiltersType(StrEnum):
+ PRE_FILTERS = "pre_filters"
+ POST_WEIGHT_FILTERS = "post_weight_filters"
+ POST_EVALUATE_FILTERS = "post_evaluate_filters"
+ POST_STATISTICS_FILTERS = "post_statistics_filters"
diff --git a/monai/fl/utils/exchange_object.py b/monai/fl/utils/exchange_object.py
new file mode 100644
index 00000000000..9772de8e422
--- /dev/null
+++ b/monai/fl/utils/exchange_object.py
@@ -0,0 +1,107 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Optional
+
+from monai.fl.utils.constants import WeightType
+
+
+class ExchangeObject(dict):
+ """
+ Contains the information shared between client and server.
+
+ Args:
+ weights: model weights.
+ optim: optimizer weights.
+ metrics: evaluation metrics.
+ weight_type: type of weights (see monai.fl.utils.constants.WeightType).
+ statistics: training statistics, i.e. number executed iterations.
+ """
+
+ def __init__(
+ self,
+ weights=None,
+ optim=None,
+ metrics: Optional[Dict] = None,
+ weight_type: Optional[Dict] = None,
+ statistics: Optional[Dict] = None,
+ ):
+ super().__init__()
+ self.weights = weights
+ self.optim = optim
+ self.metrics = metrics
+ self.weight_type = weight_type
+ self.statistics = statistics
+ self._summary: Dict = {}
+
+ @property
+ def metrics(self):
+ return self._metrics
+
+ @metrics.setter
+ def metrics(self, metrics):
+ if metrics is not None:
+ if not isinstance(metrics, dict):
+ raise ValueError(f"Expected metrics to be of type dict but received {type(metrics)}")
+ self._metrics = metrics
+
+ @property
+ def statistics(self):
+ return self._statistics
+
+ @statistics.setter
+ def statistics(self, statistics):
+ if statistics is not None:
+ if not isinstance(statistics, dict):
+ raise ValueError(f"Expected statistics to be of type dict but received {type(statistics)}")
+ self._statistics = statistics
+
+ @property
+ def weight_type(self):
+ return self._weight_type
+
+ @weight_type.setter
+ def weight_type(self, weight_type):
+ if weight_type is not None:
+ if weight_type not in [WeightType.WEIGHTS, WeightType.WEIGHT_DIFF]:
+ raise ValueError(f"Expected weight type to be either {WeightType.WEIGHTS} or {WeightType.WEIGHT_DIFF}")
+ self._weight_type = weight_type
+
+ def is_valid_weights(self):
+ if not self.weights:
+ return False
+ if not self.weight_type:
+ return False
+ return True
+
+ def _add_to_summary(self, key, value):
+ if value:
+ if isinstance(value, dict):
+ self._summary[key] = len(value)
+ elif isinstance(value, WeightType):
+ self._summary[key] = value
+ else:
+ self._summary[key] = type(value)
+
+ def summary(self):
+ self._summary.update(self)
+ for k, v in zip(
+ ["weights", "optim", "metrics", "weight_type", "statistics"],
+ [self.weights, self.optim, self.metrics, self.weight_type, self.statistics],
+ ):
+ self._add_to_summary(k, v)
+ return self._summary
+
+ def __repr__(self):
+ return str(self.summary())
+
+ def __str__(self):
+ return str(self.summary())
diff --git a/monai/fl/utils/filters.py b/monai/fl/utils/filters.py
new file mode 100644
index 00000000000..b205ffe6681
--- /dev/null
+++ b/monai/fl/utils/filters.py
@@ -0,0 +1,55 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+
+from monai.fl.utils.exchange_object import ExchangeObject
+
+
+class Filter(abc.ABC):
+ """
+ Used to apply filter to content of ExchangeObject.
+ """
+
+ @abc.abstractmethod
+ def __call__(self, data: ExchangeObject, extra=None) -> ExchangeObject:
+ """
+ Run the filtering.
+
+ Arguments:
+ data: ExchangeObject containing some data.
+
+ Returns:
+ ExchangeObject: filtered data.
+ """
+
+ raise NotImplementedError
+
+
+class SummaryFilter(Filter):
+ """
+ Summary filter to content of ExchangeObject.
+ """
+
+ def __call__(self, data: ExchangeObject, extra=None) -> ExchangeObject:
+ """
+ Example filter that doesn't filter anything but only prints data summary.
+
+ Arguments:
+ data: ExchangeObject containing some data.
+
+ Returns:
+ ExchangeObject: filtered data.
+ """
+
+ print(f"Summary of ExchangeObject: {data.summary()}")
+
+ return data
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index cffbe463914..9880e39817f 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -18,12 +18,15 @@
from .garbage_collector import GarbageCollector
from .hausdorff_distance import HausdorffDistance
from .ignite_metric import IgniteMetric
+from .logfile_handler import LogfileHandler
from .lr_schedule_handler import LrScheduleHandler
from .mean_dice import MeanDice
+from .mean_iou import MeanIoUHandler
from .metric_logger import MetricLogger, MetricLoggerKeys
from .metrics_saver import MetricsSaver
from .mlflow_handler import MLFlowHandler
from .nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
+from .panoptic_quality import PanopticQuality
from .parameter_scheduler import ParamSchedulerHandler
from .postprocessing import PostProcessing
from .probability_maps import ProbMapProducer
diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py
index f7abca4aa03..d415b787e2a 100644
--- a/monai/handlers/checkpoint_saver.py
+++ b/monai/handlers/checkpoint_saver.py
@@ -62,7 +62,7 @@ class CheckpointSaver:
https://pytorch.org/ignite/v0.4.5/generated/ignite.handlers.checkpoint.Checkpoint.html.
typically, it's used to resume training and compare current metric with previous N values.
key_metric_greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise,
- save the the first equally scored model. default to `False`.
+ save the first equally scored model. default to `False`.
key_metric_negative_sign: whether adding a negative sign to the metric score to compare metrics,
because for error-like metrics, smaller is better(objects with larger score are retained).
default to `False`.
diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py
index 8d57526676a..4f61fa3e006 100644
--- a/monai/handlers/earlystop_handler.py
+++ b/monai/handlers/earlystop_handler.py
@@ -20,7 +20,9 @@
if TYPE_CHECKING:
from ignite.engine import Engine
else:
- Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+ Engine, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
+ )
class EarlyStopHandler:
diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py
index f28923af687..d6f3f501443 100644
--- a/monai/handlers/ignite_metric.py
+++ b/monai/handlers/ignite_metric.py
@@ -19,9 +19,9 @@
from monai.utils import min_version, optional_import
idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
-Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
+Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base")
reinit__is_reduced, _ = optional_import(
- "ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced"
+ "ignite.metrics.metric", IgniteInfo.OPT_IMPORT_VERSION, min_version, "reinit__is_reduced", as_type="decorator"
)
if TYPE_CHECKING:
from ignite.engine import Engine
diff --git a/monai/handlers/logfile_handler.py b/monai/handlers/logfile_handler.py
new file mode 100644
index 00000000000..73c58431a9b
--- /dev/null
+++ b/monai/handlers/logfile_handler.py
@@ -0,0 +1,89 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+from typing import TYPE_CHECKING, Optional
+
+from monai.config import IgniteInfo
+from monai.utils import min_version, optional_import
+
+Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
+if TYPE_CHECKING:
+ from ignite.engine import Engine
+else:
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+
+__all__ = ["LogfileHandler"]
+
+
+class LogfileHandler:
+ """
+ Adds a `logging.FileHandler` to the attached engine's logger when the start event occurs and removes it again when
+ then completed event occurs.
+
+ A handler is needed to remove `FileHandler` object when the complete event occurs so that further runs of different
+ engines write only to the log files they should, rather than previous files. Multiple handlers can write to the same
+ file which allows output from train and evaluation engine objects to be condensed in one file. If the given output
+ directory doesn't exist it will by default be created when the start event occurs. This can be used in conjunction
+ with `CheckpointSaver` to save a log file to the same destination as the saved checkpoints. Since the handler is
+ added possibly after other logging events during initialisation, not all logging data will be retained.
+
+ Args:
+ output_dir: directory to save the log file to
+ filename: name of the file to save log to
+ loglevel: log level for the handler
+ formatter: format string for the `logging.Formatter` set for the handler
+ create_dir: if True, create `output_dir` if it doesn't exist
+ """
+
+ def __init__(
+ self,
+ output_dir: str,
+ filename: str = "log.txt",
+ loglevel: int = logging.INFO,
+ formatter: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
+ create_dir: bool = True,
+ ):
+ self.output_dir: str = output_dir
+ self.filename: str = filename
+ self.loglevel: int = loglevel
+ self.formatter: str = formatter
+ self.create_dir: bool = create_dir
+ self.logger: Optional[logging.Logger] = None
+ self.handler: Optional[logging.FileHandler] = None
+
+ def attach(self, engine: Engine) -> None:
+ self.logger = engine.logger
+ engine.add_event_handler(Events.STARTED, self._start)
+ engine.add_event_handler(Events.COMPLETED, self._completed)
+
+ def _start(self, engine: Engine) -> None:
+ if self.create_dir and not os.path.exists(self.output_dir):
+ os.makedirs(self.output_dir, exist_ok=True)
+
+ self.handler = logging.FileHandler(os.path.join(self.output_dir, self.filename))
+ self.handler.setLevel(self.loglevel)
+ self.handler.setFormatter(logging.Formatter(self.formatter))
+
+ if self.logger is not None:
+ self.logger.addHandler(self.handler)
+ else:
+ raise AttributeError("`self.logger` must not be None in start event")
+
+ def _completed(self, engine: Engine) -> None:
+ if self.logger is not None and self.handler is not None:
+ self.logger.removeHandler(self.handler)
+ self.handler.close()
+ else:
+ raise AttributeError("`self.logger` and `self.handler` must not be None in complete event")
+
+ self.handler = None
diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py
index db186bd73d2..66059bba952 100644
--- a/monai/handlers/lr_schedule_handler.py
+++ b/monai/handlers/lr_schedule_handler.py
@@ -21,7 +21,9 @@
if TYPE_CHECKING:
from ignite.engine import Engine
else:
- Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+ Engine, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
+ )
class LrScheduleHandler:
diff --git a/monai/handlers/mean_iou.py b/monai/handlers/mean_iou.py
new file mode 100644
index 00000000000..ee4602e6a7f
--- /dev/null
+++ b/monai/handlers/mean_iou.py
@@ -0,0 +1,52 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Union
+
+from monai.handlers.ignite_metric import IgniteMetric
+from monai.metrics import MeanIoU
+from monai.utils import MetricReduction
+
+
+class MeanIoUHandler(IgniteMetric):
+ """
+ Computes IoU score metric from full size Tensor and collects average over batch, class-channels, iterations.
+ """
+
+ def __init__(
+ self,
+ include_background: bool = True,
+ reduction: Union[MetricReduction, str] = MetricReduction.MEAN,
+ output_transform: Callable = lambda x: x,
+ save_details: bool = True,
+ ) -> None:
+ """
+
+ Args:
+ include_background: whether to include iou computation on the first channel of the predicted output.
+ Defaults to True.
+ reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
+ available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
+ ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
+ output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
+ construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
+ lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
+ `engine.state` and `output_transform` inherit from the ignite concept:
+ https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
+ https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
+ save_details: whether to save metric computation details per image, for example: mean iou of every image.
+ default to True, will save to `engine.state.metric_details` dict with the metric name as key.
+
+ See also:
+ :py:meth:`monai.metrics.meaniou.compute_meaniou`
+ """
+ metric_fn = MeanIoU(include_background=include_background, reduction=reduction)
+ super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py
index 350d1978dee..334f631b88a 100644
--- a/monai/handlers/metric_logger.py
+++ b/monai/handlers/metric_logger.py
@@ -22,7 +22,9 @@
if TYPE_CHECKING:
from ignite.engine import Engine
else:
- Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+ Engine, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
+ )
def _get_loss_from_output(output, loss_key: str = CommonKeys.LOSS):
diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py
index 664a1c87300..060738df362 100644
--- a/monai/handlers/mlflow_handler.py
+++ b/monai/handlers/mlflow_handler.py
@@ -22,7 +22,9 @@
if TYPE_CHECKING:
from ignite.engine import Engine
else:
- Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+ Engine, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
+ )
DEFAULT_TAG = "Loss"
diff --git a/monai/handlers/nvtx_handlers.py b/monai/handlers/nvtx_handlers.py
index 327c156f631..19f1b2e2bb5 100644
--- a/monai/handlers/nvtx_handlers.py
+++ b/monai/handlers/nvtx_handlers.py
@@ -21,9 +21,12 @@
if TYPE_CHECKING:
from ignite.engine import Engine, Events
else:
- Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
- Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
-
+ Engine, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
+ )
+ Events, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events", as_type="decorator"
+ )
__all__ = ["RangeHandler", "RangePushHandler", "RangePopHandler", "MarkHandler"]
diff --git a/monai/handlers/panoptic_quality.py b/monai/handlers/panoptic_quality.py
new file mode 100644
index 00000000000..d9e5beec594
--- /dev/null
+++ b/monai/handlers/panoptic_quality.py
@@ -0,0 +1,67 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Union
+
+from monai.handlers.ignite_metric import IgniteMetric
+from monai.metrics import PanopticQualityMetric
+from monai.utils import MetricReduction
+
+
+class PanopticQuality(IgniteMetric):
+ """
+ Computes Panoptic quality from full size Tensor and collects average over batch, class-channels, iterations.
+ """
+
+ def __init__(
+ self,
+ num_classes: int,
+ metric_name: str = "pq",
+ reduction: Union[MetricReduction, str] = MetricReduction.MEAN_BATCH,
+ match_iou_threshold: float = 0.5,
+ smooth_numerator: float = 1e-6,
+ output_transform: Callable = lambda x: x,
+ save_details: bool = True,
+ ) -> None:
+ """
+
+ Args:
+ num_classes: number of classes. The number should not count the background.
+ metric_name: output metric. The value can be "pq", "sq" or "rq".
+ reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
+ available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
+ ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
+ match_iou_threshold: IOU threshould to determine the pairing between `y_pred` and `y`. Usually,
+ it should >= 0.5, the pairing between instances of `y_pred` and `y` are identical.
+ If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the
+ maximal amout of unique pairing.
+ smooth_numerator: a small constant added to the numerator to avoid zero.
+ output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
+ construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
+ lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
+ `engine.state` and `output_transform` inherit from the ignite concept:
+ https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
+ https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
+ save_details: whether to save metric computation details per image, for example: panoptic quality of
+ every image.
+ default to True, will save to `engine.state.metric_details` dict with the metric name as key.
+
+ See also:
+ :py:meth:`monai.metrics.panoptic_quality.compute_panoptic_quality`
+ """
+ metric_fn = PanopticQualityMetric(
+ num_classes=num_classes,
+ metric_name=metric_name,
+ reduction=reduction,
+ match_iou_threshold=match_iou_threshold,
+ smooth_numerator=smooth_numerator,
+ )
+ super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py
index 67c51fd351c..233abca2e0e 100644
--- a/monai/handlers/parameter_scheduler.py
+++ b/monai/handlers/parameter_scheduler.py
@@ -33,7 +33,7 @@ class ParamSchedulerHandler:
value_calculator (Union[str,Callable]): Either a string ('linear', 'exponential', 'step' or 'multistep')
or Callable for custom logic.
vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator.
- epoch_level (bool): Whether the the step is based on epoch or iteration. Defaults to False.
+ epoch_level (bool): Whether the step is based on epoch or iteration. Defaults to False.
name (Optional[str]): Identifier of logging.logger to use, if None, defaulting to ``engine.logger``.
event (Optional[str]): Event to which the handler attaches. Defaults to Events.ITERATION_COMPLETED.
"""
@@ -45,10 +45,10 @@ def __init__(
vc_kwargs: Dict,
epoch_level: bool = False,
name: Optional[str] = None,
- event=Events.ITERATION_COMPLETED,
+ event=None,
):
self.epoch_level = epoch_level
- self.event = event
+ self.event = event if event is not None else Events.ITERATION_COMPLETED
self._calculators = {
"linear": self._linear,
diff --git a/monai/handlers/probability_maps.py b/monai/handlers/probability_maps.py
index df20e0604e8..3afc7e938ba 100644
--- a/monai/handlers/probability_maps.py
+++ b/monai/handlers/probability_maps.py
@@ -18,6 +18,7 @@
from monai.config import DtypeLike, IgniteInfo
from monai.utils import ProbMapKeys, min_version, optional_import
+from monai.utils.enums import CommonKeys
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
@@ -28,7 +29,10 @@
class ProbMapProducer:
"""
- Event handler triggered on completing every iteration to calculate and save the probability map
+ Event handler triggered on completing every iteration to calculate and save the probability map.
+ This handler use metadata from MetaTensor to create the probability map. This can be simply achieved by using
+ `monai.data.SlidingPatchWSIDataset` or `monai.data.MaskedPatchWSIDataset` as the dataset.
+
"""
def __init__(
@@ -91,8 +95,8 @@ def __call__(self, engine: Engine) -> None:
"""
if not isinstance(engine.state.batch, dict) or not isinstance(engine.state.output, dict):
raise ValueError("engine.state.batch and engine.state.output must be dictionaries.")
- names = engine.state.batch["metadata"][ProbMapKeys.NAME]
- locs = engine.state.batch["metadata"][ProbMapKeys.LOCATION]
+ names = engine.state.batch[CommonKeys.IMAGE].meta[ProbMapKeys.NAME]
+ locs = engine.state.batch[CommonKeys.IMAGE].meta[ProbMapKeys.LOCATION]
probs = engine.state.output[self.prob_key]
for name, loc, prob in zip(names, locs, probs):
self.prob_map[name][tuple(loc)] = prob
diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index e3b5de2d363..9fdf5eeb9fa 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -22,7 +22,9 @@
if TYPE_CHECKING:
from ignite.engine import Engine
else:
- Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+ Engine, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
+ )
DEFAULT_KEY_VAL_FORMAT = "{}: {:.4f} "
DEFAULT_TAG = "Loss"
diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py
index 445e3e76ca4..14701e79d93 100644
--- a/monai/handlers/tensorboard_handlers.py
+++ b/monai/handlers/tensorboard_handlers.py
@@ -25,7 +25,9 @@
if TYPE_CHECKING:
from ignite.engine import Engine
else:
- Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+ Engine, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine", as_type="decorator"
+ )
DEFAULT_TAG = "Loss"
diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 084b1021c25..1c4b6b3db2f 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -122,6 +122,9 @@ class SlidingWindowInferer(Inferer):
`inputs` and `roi_size`. Output is on the `device`.
progress: whether to print a tqdm progress bar.
cache_roi_weight_map: whether to precompute the ROI weight map.
+ cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory)
+ when input image volume is larger than this threshold (in pixels/volxels).
+ Otherwise use ``"device"``. Thus, the output may end-up on either cpu or gpu.
Note:
``sw_batch_size`` denotes the max number of windows per network inference iteration,
@@ -142,8 +145,9 @@ def __init__(
device: Union[torch.device, str, None] = None,
progress: bool = False,
cache_roi_weight_map: bool = False,
+ cpu_thresh: Optional[int] = None,
) -> None:
- Inferer.__init__(self)
+ super().__init__()
self.roi_size = roi_size
self.sw_batch_size = sw_batch_size
self.overlap = overlap
@@ -154,6 +158,7 @@ def __init__(
self.sw_device = sw_device
self.device = device
self.progress = progress
+ self.cpu_thresh = cpu_thresh
# compute_importance_map takes long time when computing on cpu. We thus
# compute it once if it's static and then save it for future usage
@@ -189,6 +194,11 @@ def __call__(
kwargs: optional keyword args to be passed to ``network``.
"""
+
+ device = self.device
+ if device is None and self.cpu_thresh is not None and inputs.shape[2:].numel() > self.cpu_thresh:
+ device = "cpu" # stitch in cpu memory if image is too large
+
return sliding_window_inference(
inputs,
self.roi_size,
@@ -200,7 +210,7 @@ def __call__(
self.padding_mode,
self.cval,
self.sw_device,
- self.device,
+ device,
self.progress,
self.roi_weight_map,
*args,
diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py
index 5126b23c0a6..b8989f4f124 100644
--- a/monai/inferers/utils.py
+++ b/monai/inferers/utils.py
@@ -157,7 +157,8 @@ def sliding_window_inference(
raise RuntimeError(
"Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
) from e
- importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
+ importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]
+
# handle non-positive weights
min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
@@ -276,9 +277,10 @@ def sliding_window_inference(
final_output = dict(zip(dict_key, output_image_list))
else:
final_output = tuple(output_image_list) # type: ignore
- final_output = final_output[0] if is_tensor_output else final_output # type: ignore
+ final_output = final_output[0] if is_tensor_output else final_output
+
if isinstance(inputs, MetaTensor):
- final_output = convert_to_dst_type(final_output, inputs)[0] # type: ignore
+ final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore
return final_output
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index 925649d9b17..a3c4bf1c5c3 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -26,6 +26,7 @@
generalized_dice_focal,
generalized_wasserstein_dice,
)
+from .ds_loss import DeepSupervisionLoss
from .focal_loss import FocalLoss
from .giou_loss import BoxGIoULoss, giou
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py
index cd5b261acf9..ad53269f82a 100644
--- a/monai/losses/contrastive.py
+++ b/monai/losses/contrastive.py
@@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from distutils.log import warn
+
import torch
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss
@@ -17,7 +19,6 @@
class ContrastiveLoss(_Loss):
-
"""
Compute the Contrastive loss defined in:
@@ -30,11 +31,10 @@ class ContrastiveLoss(_Loss):
"""
@deprecated_arg(name="reduction", since="0.8", msg_suffix="`reduction` is no longer supported.")
- def __init__(self, temperature: float = 0.5, batch_size: int = 1, reduction="sum") -> None:
+ def __init__(self, temperature: float = 0.5, batch_size: int = -1, reduction="sum") -> None:
"""
Args:
temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5.
- batch_size: The number of samples.
Raises:
ValueError: When an input of dimension length > 2 is passed
@@ -46,10 +46,11 @@ def __init__(self, temperature: float = 0.5, batch_size: int = 1, reduction="sum
"""
super().__init__()
-
- self.batch_size = batch_size
self.temperature = temperature
+ if batch_size != -1:
+ warn("batch_size is no longer required to be set. It will be estimated dynamically in the forward call")
+
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
@@ -66,17 +67,18 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
raise ValueError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
temperature_tensor = torch.as_tensor(self.temperature).to(input.device)
+ batch_size = input.shape[0]
norm_i = F.normalize(input, dim=1)
norm_j = F.normalize(target, dim=1)
- negatives_mask = ~torch.eye(self.batch_size * 2, self.batch_size * 2, dtype=torch.bool)
+ negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool)
negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device)
repr = torch.cat([norm_i, norm_j], dim=0)
sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2)
- sim_ij = torch.diag(sim_matrix, self.batch_size)
- sim_ji = torch.diag(sim_matrix, -self.batch_size)
+ sim_ij = torch.diag(sim_matrix, batch_size)
+ sim_ji = torch.diag(sim_matrix, -batch_size)
positives = torch.cat([sim_ij, sim_ji], dim=0)
nominator = torch.exp(positives / temperature_tensor)
@@ -84,4 +86,4 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))
- return torch.sum(loss_partial) / (2 * self.batch_size)
+ return torch.sum(loss_partial) / (2 * batch_size)
diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 892d71d06a5..1af4c519b35 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -21,7 +21,7 @@
from monai.losses.focal_loss import FocalLoss
from monai.losses.spatial_mask import MaskedLoss
from monai.networks import one_hot
-from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option
+from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
class DiceLoss(_Loss):
@@ -60,12 +60,12 @@ def __init__(
include_background: if False, channel index 0 (background category) is excluded from the calculation.
if the non-background segmentations are small compared to the total image size they can get overwhelmed
by the signal from the background so excluding it in such cases helps convergence.
- to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+ to_onehot_y: whether to convert the ``target`` into the one-hot format,
+ using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction.
softmax: if True, apply a softmax function to the prediction.
- other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
- other activation layers, Defaults to ``None``. for example:
- `other_act = torch.tanh`.
+ other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+ ``other_act = torch.tanh``.
squared_pred: use squared versions of targets and predictions in the denominator or not.
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -247,12 +247,12 @@ def __init__(
"""
Args:
include_background: If False channel index 0 (background category) is excluded from the calculation.
- to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+ to_onehot_y: whether to convert the ``target`` into the one-hot format,
+ using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: If True, apply a sigmoid function to the prediction.
softmax: If True, apply a softmax function to the prediction.
- other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
- other activation layers, Defaults to ``None``. for example:
- `other_act = torch.tanh`.
+ other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+ ``other_act = torch.tanh``.
w_type: {``"square"``, ``"simple"``, ``"uniform"``}
Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -357,9 +357,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)
w = w + infs * max_values
- final_reduce_dim = 0 if self.batch else 1
- numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
- denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
+ numer = 2.0 * (intersection * w) + self.smooth_nr
+ denom = (denominator * w) + self.smooth_dr
f: torch.Tensor = 1.0 - (numer / denom)
if self.reduction == LossReduction.MEAN.value:
@@ -639,14 +638,14 @@ def __init__(
``reduction`` is used for both losses and other parameters are only used for dice loss.
include_background: if False channel index 0 (background category) is excluded from the calculation.
- to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+ to_onehot_y: whether to convert the ``target`` into the one-hot format,
+ using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
don't need to specify activation function for `CrossEntropyLoss`.
softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
don't need to specify activation function for `CrossEntropyLoss`.
- other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
- other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
- only used by the `DiceLoss`, don't need to specify activation function for `CrossEntropyLoss`.
+ other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+ ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss`.
squared_pred: use squared versions of targets and predictions in the denominator or not.
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
reduction: {``"mean"``, ``"sum"``}
@@ -692,6 +691,7 @@ def __init__(
raise ValueError("lambda_ce should be no less than 0.0.")
self.lambda_dice = lambda_dice
self.lambda_ce = lambda_ce
+ self.old_pt_ver = not pytorch_after(1, 10)
def ce(self, input: torch.Tensor, target: torch.Tensor):
"""
@@ -701,12 +701,18 @@ def ce(self, input: torch.Tensor, target: torch.Tensor):
"""
n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
- if n_pred_ch == n_target_ch:
- # target is in the one-hot format, convert to BH[WD] format to calculate ce loss
- target = torch.argmax(target, dim=1)
- else:
+ if n_pred_ch != n_target_ch and n_target_ch == 1:
target = torch.squeeze(target, dim=1)
- target = target.long()
+ target = target.long()
+ elif self.old_pt_ver:
+ warnings.warn(
+ f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. "
+ "Using argmax (as a workaround) to convert target to a single channel."
+ )
+ target = torch.argmax(target, dim=1)
+ elif not torch.is_floating_point(target):
+ target = target.to(dtype=input.dtype)
+
return self.cross_entropy(input, target)
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
@@ -721,7 +727,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
if len(input.shape) != len(target.shape):
- raise ValueError("the number of dimensions for input and target should be the same.")
+ raise ValueError(
+ "the number of dimensions for input and target should be the same, "
+ f"got shape {input.shape} and {target.shape}."
+ )
dice_loss = self.dice(input, target)
ce_loss = self.ce(input, target)
@@ -736,6 +745,10 @@ class DiceFocalLoss(_Loss):
The details of Dice loss is shown in ``monai.losses.DiceLoss``.
The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
+ ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for the focal loss.
+ ``include_background`` and ``reduction`` are used for both losses
+ and other parameters are only used for dice loss.
+
"""
def __init__(
@@ -758,18 +771,15 @@ def __init__(
) -> None:
"""
Args:
- ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for focal loss.
- ``include_background``, ``to_onehot_y``and ``reduction`` are used for both losses
- and other parameters are only used for dice loss.
include_background: if False channel index 0 (background category) is excluded from the calculation.
- to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+ to_onehot_y: whether to convert the ``target`` into the one-hot format,
+ using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
don't need to specify activation function for `FocalLoss`.
softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
don't need to specify activation function for `FocalLoss`.
- other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
- other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
- only used by the `DiceLoss`, don't need to specify activation function for `FocalLoss`.
+ other_act: callable function to execute other activation layers, Defaults to ``None``.
+ for example: `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.
squared_pred: use squared versions of targets and predictions in the denominator or not.
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -796,6 +806,8 @@ def __init__(
"""
super().__init__()
self.dice = DiceLoss(
+ include_background=include_background,
+ to_onehot_y=False,
sigmoid=sigmoid,
softmax=softmax,
other_act=other_act,
@@ -806,7 +818,13 @@ def __init__(
smooth_dr=smooth_dr,
batch=batch,
)
- self.focal = FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction)
+ self.focal = FocalLoss(
+ include_background=include_background,
+ to_onehot_y=False,
+ gamma=gamma,
+ weight=focal_weight,
+ reduction=reduction,
+ )
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
if lambda_focal < 0.0:
@@ -814,7 +832,6 @@ def __init__(
self.lambda_dice = lambda_dice
self.lambda_focal = lambda_focal
self.to_onehot_y = to_onehot_y
- self.include_background = include_background
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
@@ -829,24 +846,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
if len(input.shape) != len(target.shape):
- raise ValueError("the number of dimensions for input and target should be the same.")
-
- n_pred_ch = input.shape[1]
-
+ raise ValueError(
+ "the number of dimensions for input and target should be the same, "
+ f"got shape {input.shape} and {target.shape}."
+ )
if self.to_onehot_y:
+ n_pred_ch = input.shape[1]
if n_pred_ch == 1:
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
else:
target = one_hot(target, num_classes=n_pred_ch)
-
- if not self.include_background:
- if n_pred_ch == 1:
- warnings.warn("single channel prediction, `include_background=False` ignored.")
- else:
- # if skipping background, removing first channel
- target = target[:, 1:]
- input = input[:, 1:]
-
dice_loss = self.dice(input, target)
focal_loss = self.focal(input, target)
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
@@ -860,11 +869,13 @@ class GeneralizedDiceFocalLoss(torch.nn.modules.loss._Loss):
Args:
include_background (bool, optional): if False channel index 0 (background category) is excluded from the calculation.
Defaults to True.
- to_onehot_y (bool, optional): whether to convert `y` into the one-hot format. Defaults to False.
+ to_onehot_y: whether to convert the ``target`` into the one-hot format,
+ using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.
softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.
- other_act (Optional[Callable], optional): if don't want to use sigmoid or softmax, use other callable
- function to execute other activation layers. Defaults to None.
+ other_act (Optional[Callable], optional): callable function to execute other activation layers,
+ Defaults to ``None``. for example: `other_act = torch.tanh`.
+ only used by the `GeneralizedDiceLoss`, not for the `FocalLoss`.
w_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
ground-truth volume to a weight factor. Defaults to ``"square"``.
reduction (Union[LossReduction, str], optional): {``"none"``, ``"mean"``, ``"sum"``}. Specified the reduction to
diff --git a/monai/losses/ds_loss.py b/monai/losses/ds_loss.py
new file mode 100644
index 00000000000..c0425ffdde0
--- /dev/null
+++ b/monai/losses/ds_loss.py
@@ -0,0 +1,85 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from torch.nn.modules.loss import _Loss
+
+from monai.utils import pytorch_after
+
+
+class DeepSupervisionLoss(_Loss):
+ """
+ Wrapper class around the main loss function to accept a list of tensors returned from a deeply
+ supervised networks. The final loss is computed as the sum of weighted losses for each of deep supervision levels.
+ """
+
+ def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: Optional[List[float]] = None) -> None:
+ """
+ Args:
+ loss: main loss instance, e.g DiceLoss().
+ weight_mode: {``"same"``, ``"exp"``, ``"two"``}
+ Specifies the weights calculation for each image level. Defaults to ``"exp"``.
+ - ``"same"``: all weights are equal to 1.
+ - ``"exp"``: exponentially decreasing weights by a power of 2: 0, 0.5, 0.25, 0.125, etc .
+ - ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc
+ weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used
+ regardless of the weight_mode
+ """
+ super().__init__()
+ self.loss = loss
+ self.weight_mode = weight_mode
+ self.weights = weights
+ self.interp_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest"
+
+ def get_weights(self, levels: int = 1) -> List[float]:
+ """
+ Calculates weights for a given number of scale levels
+ """
+ levels = max(1, levels)
+ if self.weights is not None and len(self.weights) >= levels:
+ weights = self.weights[:levels]
+ elif self.weight_mode == "same":
+ weights = [1.0] * levels
+ elif self.weight_mode == "exp":
+ weights = [max(0.5**l, 0.0625) for l in range(levels)]
+ elif self.weight_mode == "two":
+ weights = [1.0 if l == 0 else 0.5 for l in range(levels)]
+ else:
+ weights = [1.0] * levels
+
+ return weights
+
+ def get_loss(self, input: torch.Tensor, target: torch.Tensor):
+ """
+ Calculates a loss output accounting for differences in shapes,
+ and downsizing targets if necessary (using nearest neighbor interpolation)
+ Generally downsizing occurs for all level, except for the first (level==0)
+ """
+ if input.shape[2:] != target.shape[2:]:
+ target = F.interpolate(target, size=input.shape[2:], mode=self.interp_mode)
+ return self.loss(input, target)
+
+ def forward(self, input: Union[torch.Tensor, List[torch.Tensor]], target: torch.Tensor):
+
+ if isinstance(input, (list, tuple)):
+ weights = self.get_weights(levels=len(input))
+ loss = torch.tensor(0, dtype=torch.float, device=target.device)
+ for l in range(len(input)):
+ loss += weights[l] * self.get_loss(input[l].float(), target)
+ return loss
+
+ return self.loss(input.float(), target)
+
+
+ds_loss = DeepSupervisionLoss
diff --git a/monai/losses/giou_loss.py b/monai/losses/giou_loss.py
index ec7e358f429..623e55921ba 100644
--- a/monai/losses/giou_loss.py
+++ b/monai/losses/giou_loss.py
@@ -19,7 +19,6 @@
class BoxGIoULoss(_Loss):
-
"""
Compute the generalized intersection over union (GIoU) loss of a pair of boxes.
The two inputs should have the same shape. giou_loss = 1.0 - giou
diff --git a/monai/losses/image_dissimilarity.py b/monai/losses/image_dissimilarity.py
index a06f6fb5cd3..4351199aeea 100644
--- a/monai/losses/image_dissimilarity.py
+++ b/monai/losses/image_dissimilarity.py
@@ -8,14 +8,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
import torch
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss
from monai.networks.layers import gaussian_1d, separable_filtering
-from monai.utils import LossReduction, deprecated_arg
+from monai.utils import LossReduction
from monai.utils.module import look_up_option
@@ -60,16 +60,14 @@ class LocalNormalizedCrossCorrelationLoss(_Loss):
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""
- @deprecated_arg(name="ndim", since="0.6", msg_suffix="Please use `spatial_dims` instead.")
def __init__(
self,
spatial_dims: int = 3,
kernel_size: int = 3,
kernel_type: str = "rectangular",
reduction: Union[LossReduction, str] = LossReduction.MEAN,
- smooth_nr: float = 1e-5,
+ smooth_nr: float = 0.0,
smooth_dr: float = 1e-5,
- ndim: Optional[int] = None,
) -> None:
"""
Args:
@@ -85,13 +83,9 @@ def __init__(
smooth_nr: a small constant added to the numerator to avoid nan.
smooth_dr: a small constant added to the denominator to avoid nan.
- .. deprecated:: 0.6.0
- ``ndim`` is deprecated, use ``spatial_dims``.
"""
super().__init__(reduction=LossReduction(reduction).value)
- if ndim is not None:
- spatial_dims = ndim
self.ndim = spatial_dims
if self.ndim not in {1, 2, 3}:
raise ValueError(f"Unsupported ndim: {self.ndim}-d, only 1-d, 2-d, and 3-d inputs are supported")
@@ -102,6 +96,7 @@ def __init__(
_kernel = look_up_option(kernel_type, kernel_dict)
self.kernel = _kernel(self.kernel_size)
+ self.kernel.require_grads = False
self.kernel_vol = self.get_kernel_vol()
self.smooth_nr = float(smooth_nr)
@@ -126,14 +121,15 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != pred.shape:
raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")
- t2, p2, tp = target**2, pred**2, target * pred
+ t2, p2, tp = target * target, pred * pred, target * pred
kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred)
+ kernels = [kernel] * self.ndim
# sum over kernel
- t_sum = separable_filtering(target, kernels=[kernel.to(pred)] * self.ndim)
- p_sum = separable_filtering(pred, kernels=[kernel.to(pred)] * self.ndim)
- t2_sum = separable_filtering(t2, kernels=[kernel.to(pred)] * self.ndim)
- p2_sum = separable_filtering(p2, kernels=[kernel.to(pred)] * self.ndim)
- tp_sum = separable_filtering(tp, kernels=[kernel.to(pred)] * self.ndim)
+ t_sum = separable_filtering(target, kernels=kernels)
+ p_sum = separable_filtering(pred, kernels=kernels)
+ t2_sum = separable_filtering(t2, kernels=kernels)
+ p2_sum = separable_filtering(p2, kernels=kernels)
+ tp_sum = separable_filtering(tp, kernels=kernels)
# average over kernel
t_avg = t_sum / kernel_vol
@@ -149,12 +145,13 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# = sum[t*p] - sum[t] * mean[p] = cross
# the following is actually squared ncc
cross = tp_sum - p_avg * t_sum
- t_var = t2_sum - t_avg * t_sum # std[t] ** 2
- p_var = p2_sum - p_avg * p_sum # std[p] ** 2
- t_var = torch.max(t_var, torch.zeros_like(t_var))
- p_var = torch.max(p_var, torch.zeros_like(p_var))
- ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var + self.smooth_dr)
- # shape = (batch, 1, D, H, W)
+ t_var = torch.max(
+ t2_sum - t_avg * t_sum, torch.as_tensor(self.smooth_dr, dtype=t2_sum.dtype, device=t2_sum.device)
+ )
+ p_var = torch.max(
+ p2_sum - p_avg * p_sum, torch.as_tensor(self.smooth_dr, dtype=p2_sum.dtype, device=p2_sum.device)
+ )
+ ncc: torch.Tensor = (cross * cross + self.smooth_nr) / (t_var * p_var)
if self.reduction == LossReduction.SUM.value:
return torch.sum(ncc).neg() # sum over the batch, channel and spatial ndims
diff --git a/monai/losses/ssim_loss.py b/monai/losses/ssim_loss.py
index 240023cdc43..4a1ceb5a167 100644
--- a/monai/losses/ssim_loss.py
+++ b/monai/losses/ssim_loss.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import torch
import torch.nn.functional as F
from torch import nn
diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py
index ee6d7d933b0..a0735c24e0e 100644
--- a/monai/losses/tversky.py
+++ b/monai/losses/tversky.py
@@ -20,7 +20,6 @@
class TverskyLoss(_Loss):
-
"""
Compute the Tversky loss defined in:
diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py
index 1f6d51beadb..1e2bdae7257 100644
--- a/monai/losses/unified_focal_loss.py
+++ b/monai/losses/unified_focal_loss.py
@@ -216,7 +216,7 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
y_true = one_hot(y_true, num_classes=self.num_classes)
if torch.max(y_true) != self.num_classes - 1:
- raise ValueError(f"Pelase make sure the number of classes is {self.num_classes-1}")
+ raise ValueError(f"Please make sure the number of classes is {self.num_classes-1}")
n_pred_ch = y_pred.shape[1]
if self.to_onehot_y:
diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py
index 4d472a148ff..ff5eb4881a6 100644
--- a/monai/metrics/__init__.py
+++ b/monai/metrics/__init__.py
@@ -9,14 +9,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
from .cumulative_average import CumulativeAverage
from .froc import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
-from .meandice import DiceMetric, compute_meandice
-from .meaniou import MeanIoU, compute_meaniou
+from .meandice import DiceMetric, compute_dice, compute_meandice
+from .meaniou import MeanIoU, compute_iou, compute_meaniou
from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
+from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality
from .regression import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric, SSIMMetric
from .rocauc import ROCAUCMetric, compute_roc_auc
from .surface_dice import SurfaceDiceMetric, compute_surface_dice
diff --git a/monai/metrics/active_learning_metrics.py b/monai/metrics/active_learning_metrics.py
new file mode 100644
index 00000000000..eddc82e87af
--- /dev/null
+++ b/monai/metrics/active_learning_metrics.py
@@ -0,0 +1,205 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import Any
+
+import torch
+
+from monai.metrics.utils import ignore_background
+from monai.utils import MetricReduction
+
+from .metric import Metric
+
+
+class VarianceMetric(Metric):
+ """
+ Compute the Variance of a given T-repeats N-dimensional array/tensor. The primary usage is as an uncertainty based
+ metric for Active Learning.
+
+ It can return the spatial variance/uncertainty map based on user choice or a single scalar value via mean/sum of the
+ variance for scoring purposes
+
+ Args:
+ include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
+ spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image dimensions
+ scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used
+ threshold: To avoid NaN's a threshold is used to replace zero's
+
+ """
+
+ def __init__(
+ self,
+ include_background: bool = True,
+ spatial_map: bool = False,
+ scalar_reduction: str = "sum",
+ threshold: float = 0.0005,
+ ) -> None:
+ super().__init__()
+ self.include_background = include_background
+ self.spatial_map = spatial_map
+ self.scalar_reduction = scalar_reduction
+ self.threshold = threshold
+
+ def __call__(self, y_pred: Any) -> Any:
+ """
+ Args:
+ y_pred: Predicted segmentation, typically segmentation model output.
+ It must be N-repeats, repeat-first tensor [N,C,H,W,D].
+
+ Returns:
+ Pytorch tensor of scalar value of variance as uncertainty or a spatial map of uncertainty
+
+ """
+ return compute_variance(
+ y_pred=y_pred,
+ include_background=self.include_background,
+ spatial_map=self.spatial_map,
+ scalar_reduction=self.scalar_reduction,
+ threshold=self.threshold,
+ )
+
+
+class LabelQualityScore(Metric):
+ """
+ The assumption is that the DL model makes better predictions than the provided label quality, hence the difference
+ can be treated as a label quality score
+
+ It can be combined with variance/uncertainty for active learning frameworks to factor in the quality of label along
+ with uncertainty
+ Args:
+ include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
+ spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image
+ dimensions
+ scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used
+
+ """
+
+ def __init__(self, include_background: bool = True, scalar_reduction: str = "sum") -> None:
+ super().__init__()
+ self.include_background = include_background
+ self.scalar_reduction = scalar_reduction
+
+ def __call__(self, y_pred: Any, y: Any):
+ """
+ Args:
+ y_pred: Predicted segmentation, typically segmentation model output.
+ It must be N-repeats, repeat-first tensor [N,C,H,W,D].
+
+ Returns:
+ Pytorch tensor of scalar value of variance as uncertainty or a spatial map of uncertainty
+
+ """
+ return label_quality_score(
+ y_pred=y_pred, y=y, include_background=self.include_background, scalar_reduction=self.scalar_reduction
+ )
+
+
+def compute_variance(
+ y_pred: torch.Tensor,
+ include_background: bool = True,
+ spatial_map: bool = False,
+ scalar_reduction: str = "mean",
+ threshold: float = 0.0005,
+):
+ """
+ Args:
+ y_pred: [N, C, H, W, D] or [N, C, H, W] or [N, C, H] where N is repeats, C is channels and H, W, D stand for
+ Height, Width & Depth
+ include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
+ spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image
+ dimensions
+ scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used
+ threshold: To avoid NaN's a threshold is used to replace zero's
+ Returns:
+ A single scalar uncertainty/variance value or the spatial map of uncertainty/variance
+ """
+
+ # The background utils is only applicable here because instead of Batch-dimension we have repeats here
+ y_pred = y_pred.float()
+
+ if not include_background:
+ y = y_pred
+ # TODO If this utils is made to be optional for 'y' it would be nice
+ y_pred, y = ignore_background(y_pred=y_pred, y=y)
+
+ # Set any values below 0 to threshold
+ y_pred[y_pred <= 0] = threshold
+
+ n_len = len(y_pred.shape)
+
+ if n_len < 4 and spatial_map:
+ warnings.warn("Spatial map requires a 2D/3D image with N-repeats and C-channels")
+ return None
+
+ # Create new shape list
+ # The N-repeats are multiplied by channels
+ n_shape = y_pred.shape
+ new_shape = [n_shape[0] * n_shape[1]]
+ for each_dim_idx in range(2, n_len):
+ new_shape.append(n_shape[each_dim_idx])
+
+ y_reshaped = torch.reshape(y_pred, new_shape)
+ variance = torch.var(y_reshaped, dim=0, unbiased=False)
+
+ if spatial_map:
+ return variance
+
+ if scalar_reduction == MetricReduction.MEAN:
+ return torch.mean(variance)
+ if scalar_reduction == MetricReduction.SUM:
+ return torch.sum(variance)
+ raise ValueError(f"scalar_reduction={scalar_reduction} not supported.")
+
+
+def label_quality_score(
+ y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, scalar_reduction: str = "mean"
+):
+ """
+ The assumption is that the DL model makes better predictions than the provided label quality, hence the difference
+ can be treated as a label quality score
+
+ Args:
+ y_pred: Input data of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is
+ channels and H, W, D stand for Height, Width & Depth
+ y: Ground Truth of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is channels
+ and H, W, D stand for Height, Width & Depth
+ include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
+ scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used to retrieve a single scalar
+ value, if set to 'none' a spatial map will be returned
+
+ Returns:
+ A single scalar absolute difference value as score with a reduction based on sum/mean or the spatial map of
+ absolute difference
+ """
+
+ # The background utils is only applicable here because instead of Batch-dimension we have repeats here
+ y_pred = y_pred.float()
+ y = y.float()
+
+ if not include_background:
+ y_pred, y = ignore_background(y_pred=y_pred, y=y)
+
+ n_len = len(y_pred.shape)
+ if n_len < 4 and scalar_reduction == "none":
+ warnings.warn("Reduction set to None, Spatial map return requires a 2D/3D image of B-Batchsize and C-channels")
+ return None
+
+ abs_diff_map = torch.abs(y_pred - y)
+
+ if scalar_reduction == MetricReduction.NONE:
+ return abs_diff_map
+
+ if scalar_reduction == MetricReduction.MEAN:
+ return torch.mean(abs_diff_map, dim=list(range(1, n_len)))
+ if scalar_reduction == MetricReduction.SUM:
+ return torch.sum(abs_diff_map, dim=list(range(1, n_len)))
+ raise ValueError(f"scalar_reduction={scalar_reduction} not supported.")
diff --git a/monai/metrics/confusion_matrix.py b/monai/metrics/confusion_matrix.py
index cdde195d3ac..da8561f45cb 100644
--- a/monai/metrics/confusion_matrix.py
+++ b/monai/metrics/confusion_matrix.py
@@ -100,9 +100,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
return get_confusion_matrix(y_pred=y_pred, y=y, include_background=self.include_background)
- def aggregate( # type: ignore
- self, compute_sample: bool = False, reduction: Union[MetricReduction, str, None] = None
- ):
+ def aggregate(self, compute_sample: bool = False, reduction: Union[MetricReduction, str, None] = None):
"""
Execute reduction for the confusion matrix values.
diff --git a/monai/metrics/cumulative_average.py b/monai/metrics/cumulative_average.py
index 768841f6c76..b099cdc2a4f 100644
--- a/monai/metrics/cumulative_average.py
+++ b/monai/metrics/cumulative_average.py
@@ -9,64 +9,150 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import torch
+import warnings
+from typing import Any, Optional
-from monai.transforms import isnan
-from monai.utils import convert_data_type
+import torch
+import torch.distributed as dist
-from .metric import Cumulative
+from monai.config import NdarrayOrTensor
-class CumulativeAverage(Cumulative):
+class CumulativeAverage:
"""
- Cumulatively record data value and aggregate for the average value.
- It supports single class or multi-class data, for example,
- value can be 0.44 (a loss value) or [0.3, 0.4] (metrics of two classes).
- It also supports distributed data parallel, sync data when aggregating.
- For example, recording loss values and compute the overall average value in every 5 iterations:
+ A utility class to keep track of average values. For example during training/validation loop,
+ we need to accumulate the per-batch metrics and calculate the final average value for the whole dataset.
+ When training in multi-gpu environment, with DistributedDataParallel, it will average across the processes.
+
+ Example:
.. code-block:: python
- average = CumulativeAverage()
- for i, d in enumerate(dataloader):
- loss = ...
- average.append(loss)
- if i % 5 == 0:
- print(f"cumulative average of loss: {average.aggregate()}")
- average.reset()
+ from monai.metrics import CumulativeAverage
+
+ run_avg = CumulativeAverage()
+ batch_size = 8
+ for i in range(len(train_set)):
+ ...
+ val = calc_metric(x,y) #some metric value
+ run_avg.append(val, count=batch_size)
+
+ val_avg = run_avg.aggregate() #average value
"""
def __init__(self) -> None:
- super().__init__()
- self.sum = None
- self.not_nans = None
+ self.reset()
- def reset(self):
+ def reset(self) -> None:
"""
- Reset all the running status, including buffers, sum, not nans count, etc.
+ Reset all stats
+ """
+ self.val: torch.Tensor = None # type: ignore
+ self.sum = torch.tensor(0, dtype=torch.float)
+ self.count = torch.tensor(0, dtype=torch.float)
+ self.is_distributed = dist.is_available() and dist.is_initialized()
+ def get_current(self, to_numpy: bool = True) -> NdarrayOrTensor:
"""
- super().reset()
- self.sum = None
- self.not_nans = None
+ returns the most recent value (averaged across processes)
- def aggregate(self):
+ Args:
+ to_numpy: whether to convert to numpy array. Defaults to True
"""
- Sync data from all the ranks and compute the average value with previous sum value.
+ if self.val is None:
+ return 0
+
+ val = self.val.clone()
+ val[~torch.isfinite(val)] = 0
+
+ if self.is_distributed:
+ val = val / dist.get_world_size()
+ dist.all_reduce(val)
+
+ if to_numpy:
+ val = val.cpu().numpy()
+ return val
+
+ def aggregate(self, to_numpy: bool = True) -> NdarrayOrTensor:
+ """
+ returns the total average value (averaged across processes)
+
+ Args:
+ to_numpy: whether to convert to numpy array. Defaults to True
"""
- data = self.get_buffer()
+ if self.val is None:
+ return 0
+
+ sum = self.sum
+ count = self.count
- # compute SUM across the batch dimension
- nans = isnan(data)
- not_nans = convert_data_type((~nans), dtype=torch.float32)[0].sum(0)
- data[nans] = 0
- f = data.sum(0)
+ if self.is_distributed:
+ sum = sum.to(self.val, copy=True)
+ count = count.to(self.val, copy=True)
+ dist.all_reduce(sum)
+ dist.all_reduce(count)
- # clear the buffer for next update
- super().reset()
- self.sum = f if self.sum is None else (self.sum + f)
- self.not_nans = not_nans if self.not_nans is None else (self.not_nans + not_nans)
+ val = torch.where(count > 0, sum / count, sum)
- return self.sum / self.not_nans
+ if to_numpy:
+ val = val.cpu().numpy()
+ return val
+
+ def append(self, val: Any, count: Optional[Any] = 1) -> None:
+ """
+ Append with a new value, and an optional count. Any data type is supported that is convertable
+ with torch.as_tensor() e.g. number, list, numpy array, or Tensor.
+
+ Args:
+ val: value (e.g. number, list, numpy array or Tensor) to keep track of
+ count: count (e.g. number, list, numpy array or Tensor), to update the contribution count
+
+ For example:
+ # a simple constant tracking
+ avg = CumulativeAverage()
+ avg.append(0.6)
+ avg.append(0.8)
+ print(avg.aggregate()) #prints 0.7
+
+ # an array tracking, e.g. metrics from 3 classes
+ avg= CumulativeAverage()
+ avg.append([0.2, 0.4, 0.4])
+ avg.append([0.4, 0.6, 0.4])
+ print(avg.aggregate()) #prints [0.3, 0.5. 0.4]
+
+ # different contributions / counts
+ avg= CumulativeAverage()
+ avg.append(1, count=4) #avg metric 1 coming from a batch of 4
+ avg.append(2, count=6) #avg metric 2 coming from a batch of 6
+ print(avg.aggregate()) #prints 1.6 == (1*4 +2*6)/(4+6)
+
+ # different contributions / counts
+ avg= CumulativeAverage()
+ avg.append([0.5, 0.5, 0], count=[1, 1, 0]) # last elements count is zero to ignore it
+ avg.append([0.5, 0.5, 0.5], count=[1, 1, 1]) #
+ print(avg.aggregate()) #prints [0.5, 0.5, 0,5] == ([0.5, 0.5, 0] + [0.5, 0.5, 0.5]) / ([1, 1, 0] + [1, 1, 1])
+
+ """
+ self.val = torch.as_tensor(val, dtype=torch.float)
+ if self.val.requires_grad:
+ self.val = self.val.detach().clone()
+
+ count = torch.as_tensor(count, dtype=torch.float, device="cpu")
+ if count.ndim > 0 and count.shape != self.val.shape:
+ raise ValueError(
+ f"Count shape must match val shape, unless count is a single number: {count} val {self.val.cpu()}"
+ )
+
+ val = count * self.val.cpu()
+
+ # account for possible non-finite numbers in val and replace them with 0s
+ nfin = torch.isfinite(val)
+ if not torch.all(nfin):
+ warnings.warn(f"non-finite inputs received: val: {val}, count: {count}")
+ count = torch.where(nfin, count, torch.zeros_like(count))
+ val = torch.where(nfin, val, torch.zeros_like(val))
+
+ self.count = self.count + count
+ self.sum = self.sum + val
diff --git a/monai/metrics/froc.py b/monai/metrics/froc.py
index 93ad625b90e..56e0755b99c 100644
--- a/monai/metrics/froc.py
+++ b/monai/metrics/froc.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import List, Optional, Tuple, Union
import numpy as np
diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py
index f223664bea2..3a0e90d5872 100644
--- a/monai/metrics/generalized_dice.py
+++ b/monai/metrics/generalized_dice.py
@@ -80,7 +80,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type
)
- def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # type: ignore
+ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
"""
Execute reduction logic for the output of `compute_generalized_dice`.
diff --git a/monai/metrics/hausdorff_distance.py b/monai/metrics/hausdorff_distance.py
index 61bea4c87d7..54de8b1d4d4 100644
--- a/monai/metrics/hausdorff_distance.py
+++ b/monai/metrics/hausdorff_distance.py
@@ -104,7 +104,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
directed=self.directed,
)
- def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # type: ignore
+ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
"""
Execute reduction logic for the output of `compute_hausdorff_distance`.
diff --git a/monai/metrics/meandice.py b/monai/metrics/meandice.py
index 30ef0845c73..a9d4e7182a9 100644
--- a/monai/metrics/meandice.py
+++ b/monai/metrics/meandice.py
@@ -14,7 +14,7 @@
import torch
from monai.metrics.utils import do_metric_reduction, ignore_background, is_binary_tensor
-from monai.utils import MetricReduction
+from monai.utils import MetricReduction, deprecated
from .metric import CumulativeIterationMetric
@@ -80,11 +80,11 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
if dims < 3:
raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")
# compute dice (BxC) for each channel for each batch
- return compute_meandice(
+ return compute_dice(
y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty
)
- def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # type: ignore
+ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
"""
Execute reduction logic for the output of `compute_meandice`.
@@ -103,10 +103,10 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # typ
return (f, not_nans) if self.get_not_nans else f
-def compute_meandice(
+def compute_dice(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True
) -> torch.Tensor:
- """Computes Dice score metric from full size Tensor and collects average.
+ """Computes Dice score metric for a batch of predictions.
Args:
y_pred: input data to compute, typical segmentation model output.
@@ -146,6 +146,11 @@ def compute_meandice(
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
denominator = y_o + y_pred_o
- if ignore_empty is True:
+ if ignore_empty:
return torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device))
return torch.where(denominator > 0, (2.0 * intersection) / denominator, torch.tensor(1.0, device=y_o.device))
+
+
+@deprecated(since="1.0.0", msg_suffix="use `compute_dice` instead.")
+def compute_meandice(*args, **kwargs):
+ return compute_dice(*args, **kwargs)
diff --git a/monai/metrics/meaniou.py b/monai/metrics/meaniou.py
index 8b07552a8c3..55fa73e1ffc 100644
--- a/monai/metrics/meaniou.py
+++ b/monai/metrics/meaniou.py
@@ -14,14 +14,15 @@
import torch
from monai.metrics.utils import do_metric_reduction, ignore_background, is_binary_tensor
-from monai.utils import MetricReduction
+from monai.utils import MetricReduction, deprecated
from .metric import CumulativeIterationMetric
class MeanIoU(CumulativeIterationMetric):
"""
- Compute average IoU score between two tensors. It can support both multi-classes and multi-labels tasks.
+ Compute average Intersection over Union (IoU) score between two tensors.
+ It supports both multi-classes and multi-labels tasks.
Input `y_pred` is compared with ground truth `y`.
`y_pred` is expected to have binarized predictions and `y` should be in one-hot format. You can use suitable transforms
in ``monai.transforms.post`` first to achieve binarized values.
@@ -80,11 +81,11 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
if dims < 3:
raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")
# compute IoU (BxC) for each channel for each batch
- return compute_meaniou(
+ return compute_iou(
y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty
)
- def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # type: ignore
+ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
"""
Execute reduction logic for the output of `compute_meaniou`.
@@ -103,10 +104,10 @@ def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # typ
return (f, not_nans) if self.get_not_nans else f
-def compute_meaniou(
+def compute_iou(
y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, ignore_empty: bool = True
) -> torch.Tensor:
- """Computes IoU score metric from full size Tensor and collects average.
+ """Computes Intersection over Union (IoU) score metric from a batch of predictions.
Args:
y_pred: input data to compute, typical segmentation model output.
@@ -146,6 +147,11 @@ def compute_meaniou(
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
union = y_o + y_pred_o - intersection
- if ignore_empty is True:
+ if ignore_empty:
return torch.where(y_o > 0, (intersection) / union, torch.tensor(float("nan"), device=y_o.device))
return torch.where(union > 0, (intersection) / union, torch.tensor(1.0, device=y_o.device))
+
+
+@deprecated(since="1.0.0", msg_suffix="use `compute_iou` instead.")
+def compute_meaniou(*args, **kwargs):
+ return compute_iou(*args, **kwargs)
diff --git a/monai/metrics/metric.py b/monai/metrics/metric.py
index fa8b3354de1..e92ed73dd64 100644
--- a/monai/metrics/metric.py
+++ b/monai/metrics/metric.py
@@ -45,7 +45,7 @@ class IterationMetric(Metric):
Subclasses typically implement the `_compute_tensor` function for the actual tensor computation logic.
"""
- def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # type: ignore
+ def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None):
"""
Execute basic computation for model prediction `y_pred` and ground truth `y` (optional).
It supports inputs of a list of "channel-first" Tensor and a "batch-first" Tensor.
@@ -310,7 +310,7 @@ class CumulativeIterationMetric(Cumulative, IterationMetric):
"""
- def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None): # type: ignore
+ def __call__(self, y_pred: TensorOrList, y: Optional[TensorOrList] = None):
"""
Execute basic computation for model prediction and ground truth.
It can support both `list of channel-first Tensor` and `batch-first Tensor`.
diff --git a/monai/metrics/panoptic_quality.py b/monai/metrics/panoptic_quality.py
new file mode 100644
index 00000000000..4bf87188d5e
--- /dev/null
+++ b/monai/metrics/panoptic_quality.py
@@ -0,0 +1,292 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Sequence, Union
+
+import torch
+
+from monai.metrics.metric import CumulativeIterationMetric
+from monai.metrics.utils import do_metric_reduction, remap_instance_id
+from monai.utils import MetricReduction, ensure_tuple, optional_import
+
+linear_sum_assignment, _ = optional_import("scipy.optimize", name="linear_sum_assignment")
+
+__all__ = ["PanopticQualityMetric", "compute_panoptic_quality"]
+
+
+class PanopticQualityMetric(CumulativeIterationMetric):
+ """
+ Compute Panoptic Quality between two instance segmentation masks. If specifying `metric_name` to "SQ" or "RQ",
+ Segmentation Quality (SQ) or Recognition Quality (RQ) will be returned instead.
+
+ Panoptic Quality is a metric used in panoptic segmentation tasks. This task unifies the typically distinct tasks
+ of semantic segmentation (assign a class label to each pixel) and
+ instance segmentation (detect and segment each object instance). Compared with semantic segmentation, panoptic
+ segmentation distinguish different instances that belong to same class.
+ Compared with instance segmentation, panoptic segmentation does not allow overlap and only one semantic label and
+ one instance id can be assigned to each pixel.
+ Please refer to the following paper for more details:
+ https://openaccess.thecvf.com/content_CVPR_2019/papers/Kirillov_Panoptic_Segmentation_CVPR_2019_paper.pdf
+
+ This class also refers to the following implementation:
+ https://github.com/TissueImageAnalytics/CoNIC
+
+ Args:
+ num_classes: number of classes. The number should not count the background.
+ metric_name: output metric. The value can be "pq", "sq" or "rq".
+ Except for input only one metric, multiple metrics are also supported via input a sequence of metric names
+ such as ("pq", "sq", "rq"). If input a sequence, a list of results with the same order
+ as the input names will be returned.
+ reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
+ available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
+ ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
+ match_iou_threshold: IOU threshould to determine the pairing between `y_pred` and `y`. Usually,
+ it should >= 0.5, the pairing between instances of `y_pred` and `y` are identical.
+ If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the
+ maximal amout of unique pairing.
+ smooth_numerator: a small constant added to the numerator to avoid zero.
+
+ """
+
+ def __init__(
+ self,
+ num_classes: int,
+ metric_name: Union[Sequence[str], str] = "pq",
+ reduction: Union[MetricReduction, str] = MetricReduction.MEAN_BATCH,
+ match_iou_threshold: float = 0.5,
+ smooth_numerator: float = 1e-6,
+ ) -> None:
+ super().__init__()
+ self.num_classes = num_classes
+ self.reduction = reduction
+ self.match_iou_threshold = match_iou_threshold
+ self.smooth_numerator = smooth_numerator
+ self.metric_name = ensure_tuple(metric_name)
+
+ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore
+ """
+ Args:
+ y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the
+ second channel represent the instance predictions and classification predictions respectively.
+ y: ground truth. It must have the same shape as `y_pred` and have integer type. The first channel and the
+ second channel represent the instance labels and classification labels respectively.
+ Values in the second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`,
+ where 0 represents the background.
+
+ Raises:
+ ValueError: when `y_pred` and `y` have different shapes.
+ ValueError: when `y_pred` and `y` have != 2 channels.
+ ValueError: when `y_pred` and `y` have != 4 dimensions.
+
+ """
+ if y_pred.shape != y.shape:
+ raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
+
+ if y_pred.shape[1] != 2:
+ raise ValueError(
+ f"for panoptic quality calculation, only 2 channels input is supported, got {y_pred.shape[1]}."
+ )
+
+ dims = y_pred.ndimension()
+ if dims != 4:
+ raise ValueError(f"y_pred should have 4 dimensions (batch, 2, h, w), got {dims}.")
+
+ batch_size = y_pred.shape[0]
+
+ outputs = torch.zeros([batch_size, self.num_classes, 4], device=y_pred.device)
+
+ for b in range(batch_size):
+ true_instance, pred_instance = y[b, 0], y_pred[b, 0]
+ true_class, pred_class = y[b, 1], y_pred[b, 1]
+ for c in range(self.num_classes):
+ pred_instance_c = (pred_class == c + 1) * pred_instance
+ true_instance_c = (true_class == c + 1) * true_instance
+
+ outputs[b, c] = compute_panoptic_quality(
+ pred=pred_instance_c,
+ gt=true_instance_c,
+ remap=True,
+ match_iou_threshold=self.match_iou_threshold,
+ output_confusion_matrix=True,
+ )
+
+ return outputs
+
+ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
+ """
+ Execute reduction logic for the output of `compute_panoptic_quality`.
+
+ Args:
+ reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
+ available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
+ ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
+
+ """
+ data = self.get_buffer()
+ if not isinstance(data, torch.Tensor):
+ raise ValueError("the data to aggregate must be PyTorch Tensor.")
+
+ # do metric reduction
+ f, _ = do_metric_reduction(data, reduction or self.reduction)
+ tp, fp, fn, iou_sum = f[..., 0], f[..., 1], f[..., 2], f[..., 3]
+ results = []
+ for metric_name in self.metric_name:
+ metric_name = _check_panoptic_metric_name(metric_name)
+ if metric_name == "rq":
+ results.append(tp / (tp + 0.5 * fp + 0.5 * fn + self.smooth_numerator))
+ elif metric_name == "sq":
+ results.append(iou_sum / (tp + self.smooth_numerator))
+ else:
+ results.append(iou_sum / (tp + 0.5 * fp + 0.5 * fn + self.smooth_numerator))
+
+ return results[0] if len(results) == 1 else results
+
+
+def compute_panoptic_quality(
+ pred: torch.Tensor,
+ gt: torch.Tensor,
+ metric_name: str = "pq",
+ remap: bool = True,
+ match_iou_threshold: float = 0.5,
+ smooth_numerator: float = 1e-6,
+ output_confusion_matrix: bool = False,
+):
+ """Computes Panoptic Quality (PQ). If specifying `metric_name` to "SQ" or "RQ",
+ Segmentation Quality (SQ) or Recognition Quality (RQ) will be returned instead.
+
+ In addition, if `output_confusion_matrix` is True, the function will return a tensor with shape 4, which
+ represents the true positive, false positive, false negative and the sum of iou. These four values are used to
+ calculate PQ, and returning them directly enables further calculation over all images.
+
+ Args:
+ pred: input data to compute, it must be in the form of HW and have integer type.
+ gt: ground truth. It must have the same shape as `pred` and have integer type.
+ metric_name: output metric. The value can be "pq", "sq" or "rq".
+ remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id.
+ match_iou_threshold: IOU threshould to determine the pairing between `pred` and `gt`. Usually,
+ it should >= 0.5, the pairing between instances of `pred` and `gt` are identical.
+ If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the
+ maximal amout of unique pairing.
+ smooth_numerator: a small constant added to the numerator to avoid zero.
+
+ Raises:
+ ValueError: when `pred` and `gt` have different shapes.
+ ValueError: when `match_iou_threshold` <= 0.0 or > 1.0.
+
+ """
+
+ if gt.shape != pred.shape:
+ raise ValueError(f"pred and gt should have same shapes, got {pred.shape} and {gt.shape}.")
+ if match_iou_threshold <= 0.0 or match_iou_threshold > 1.0:
+ raise ValueError(f"'match_iou_threshold' should be within (0, 1], got: {match_iou_threshold}.")
+
+ gt = gt.int()
+ pred = pred.int()
+
+ if remap is True:
+ gt = remap_instance_id(gt)
+ pred = remap_instance_id(pred)
+
+ pairwise_iou, true_id_list, pred_id_list = _get_pairwise_iou(pred, gt, device=pred.device)
+ paired_iou, paired_true, paired_pred = _get_paired_iou(
+ pairwise_iou, match_iou_threshold, device=pairwise_iou.device
+ )
+
+ unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true]
+ unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred]
+
+ tp, fp, fn = len(paired_true), len(unpaired_pred), len(unpaired_true)
+ iou_sum = paired_iou.sum()
+
+ if output_confusion_matrix:
+ return torch.as_tensor([tp, fp, fn, iou_sum], device=pred.device)
+
+ metric_name = _check_panoptic_metric_name(metric_name)
+ if metric_name == "rq":
+ return torch.as_tensor(tp / (tp + 0.5 * fp + 0.5 * fn + smooth_numerator), device=pred.device)
+ if metric_name == "sq":
+ return torch.as_tensor(iou_sum / (tp + smooth_numerator), device=pred.device)
+ return torch.as_tensor(iou_sum / (tp + 0.5 * fp + 0.5 * fn + smooth_numerator), device=pred.device)
+
+
+def _get_id_list(gt: torch.Tensor):
+ id_list = list(gt.unique())
+ # ensure id 0 is included
+ if 0 not in id_list:
+ id_list.insert(0, torch.tensor(0).int())
+
+ return id_list
+
+
+def _get_pairwise_iou(pred: torch.Tensor, gt: torch.Tensor, device: Union[str, torch.device] = "cpu"):
+ pred_id_list = _get_id_list(pred)
+ true_id_list = _get_id_list(gt)
+
+ pairwise_iou = torch.zeros([len(true_id_list) - 1, len(pred_id_list) - 1], dtype=torch.float, device=device)
+ true_masks: List[torch.Tensor] = []
+ pred_masks: List[torch.Tensor] = []
+
+ for t in true_id_list[1:]:
+ t_mask = torch.as_tensor(gt == t, device=device).int()
+ true_masks.append(t_mask)
+
+ for p in pred_id_list[1:]:
+ p_mask = torch.as_tensor(pred == p, device=device).int()
+ pred_masks.append(p_mask)
+
+ for true_id in range(1, len(true_id_list)):
+ t_mask = true_masks[true_id - 1]
+ pred_true_overlap = pred[t_mask > 0]
+ pred_true_overlap_id = list(pred_true_overlap.unique())
+ for pred_id in pred_true_overlap_id:
+ if pred_id == 0:
+ continue
+ p_mask = pred_masks[pred_id - 1]
+ total = (t_mask + p_mask).sum()
+ inter = (t_mask * p_mask).sum()
+ iou = inter / (total - inter)
+ pairwise_iou[true_id - 1, pred_id - 1] = iou
+
+ return pairwise_iou, true_id_list, pred_id_list
+
+
+def _get_paired_iou(
+ pairwise_iou: torch.Tensor, match_iou_threshold: float = 0.5, device: Union[str, torch.device] = "cpu"
+):
+ if match_iou_threshold >= 0.5:
+ pairwise_iou[pairwise_iou <= match_iou_threshold] = 0.0
+ paired_true, paired_pred = torch.nonzero(pairwise_iou)[:, 0], torch.nonzero(pairwise_iou)[:, 1]
+ paired_iou = pairwise_iou[paired_true, paired_pred]
+ paired_true += 1
+ paired_pred += 1
+
+ return paired_iou, paired_true, paired_pred
+
+ pairwise_iou = pairwise_iou.cpu().numpy()
+ paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
+ paired_iou = pairwise_iou[paired_true, paired_pred]
+ paired_true = torch.as_tensor(list(paired_true[paired_iou > match_iou_threshold] + 1), device=device)
+ paired_pred = torch.as_tensor(list(paired_pred[paired_iou > match_iou_threshold] + 1), device=device)
+ paired_iou = paired_iou[paired_iou > match_iou_threshold]
+
+ return paired_iou, paired_true, paired_pred
+
+
+def _check_panoptic_metric_name(metric_name: str):
+ metric_name = metric_name.replace(" ", "_")
+ metric_name = metric_name.lower()
+ if metric_name in ["panoptic_quality", "pq"]:
+ return "pq"
+ if metric_name in ["segmentation_quality", "sq"]:
+ return "sq"
+ if metric_name in ["recognition_quality", "rq"]:
+ return "rq"
+ raise ValueError(f"metric name: {metric_name} is wrong, please use 'pq', 'sq' or 'rq'.")
diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py
index d1cd44e4bb1..1c48ded306e 100644
--- a/monai/metrics/regression.py
+++ b/monai/metrics/regression.py
@@ -48,7 +48,7 @@ def __init__(
self.reduction = reduction
self.get_not_nans = get_not_nans
- def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # type: ignore
+ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
"""
Args:
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py
index 0b3e488922f..2bb8dc2b32b 100644
--- a/monai/metrics/rocauc.py
+++ b/monai/metrics/rocauc.py
@@ -51,7 +51,7 @@ def __init__(self, average: Union[Average, str] = Average.MACRO) -> None:
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignore
return y_pred, y
- def aggregate(self, average: Union[Average, str, None] = None): # type: ignore
+ def aggregate(self, average: Union[Average, str, None] = None):
"""
Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,
This function reads the buffers and computes the area under the ROC.
diff --git a/monai/metrics/surface_dice.py b/monai/metrics/surface_dice.py
index 8bc34d4afcd..80869ce5831 100644
--- a/monai/metrics/surface_dice.py
+++ b/monai/metrics/surface_dice.py
@@ -86,7 +86,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
distance_metric=self.distance_metric,
)
- def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # type: ignore
+ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
r"""
Aggregates the output of `_compute_tensor`.
diff --git a/monai/metrics/surface_distance.py b/monai/metrics/surface_distance.py
index e4637024589..8bb688b4e02 100644
--- a/monai/metrics/surface_distance.py
+++ b/monai/metrics/surface_distance.py
@@ -94,7 +94,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor): # type: ignor
distance_metric=self.distance_metric,
)
- def aggregate(self, reduction: Union[MetricReduction, str, None] = None): # type: ignore
+ def aggregate(self, reduction: Union[MetricReduction, str, None] = None):
"""
Execute reduction logic for the output of `compute_average_surface_distance`.
diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py
index c17df7a54ae..0c06c7768a9 100644
--- a/monai/metrics/utils.py
+++ b/monai/metrics/utils.py
@@ -211,9 +211,46 @@ def is_binary_tensor(input: torch.Tensor, name: str):
ValueError: if `input` is not a PyTorch Tensor.
Returns:
- Union[str, None]: warning message, if the tensor is not binary. Othwerwise, None.
+ Union[str, None]: warning message, if the tensor is not binary. Otherwise, None.
"""
if not isinstance(input, torch.Tensor):
raise ValueError(f"{name} must be of type PyTorch Tensor.")
if not torch.all(input.byte() == input) or input.max() > 1 or input.min() < 0:
warnings.warn(f"{name} should be a binarized tensor.")
+
+
+def remap_instance_id(pred: torch.Tensor, by_size: bool = False):
+ """
+ This function is used to rename all instance id of `pred`, so that the id is
+ contiguous.
+ For example: all ids of the input can be [0, 1, 2] rather than [0, 2, 5].
+ This function is helpful for calculating metrics like Panoptic Quality (PQ).
+ The implementation refers to:
+
+ https://github.com/vqdang/hover_net
+
+ Args:
+ pred: segmentation predictions in the form of torch tensor. Each
+ value of the tensor should be an integer, and represents the prediction of its corresponding instance id.
+ by_size: if True, larget instance will be assigned a smaller id.
+
+ """
+ pred_id = list(pred.unique())
+ # the original implementation has the limitation that if there is no 0 in pred, error will happen
+ pred_id = [i for i in pred_id if i != 0]
+
+ if len(pred_id) == 0:
+ return pred
+ if by_size is True:
+ instance_size = []
+ for instance_id in pred_id:
+ instance_size.append((pred == instance_id).sum())
+
+ pair_data = zip(pred_id, instance_size)
+ pair_list = sorted(pair_data, key=lambda x: x[1], reverse=True) # type: ignore
+ pred_id, _ = zip(*pair_list)
+
+ new_pred = torch.zeros_like(pred, dtype=torch.int)
+ for idx, instance_id in enumerate(pred_id):
+ new_pred[pred == instance_id] = idx + 1
+ return new_pred
diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py
index 0543b116327..b2dc907c185 100644
--- a/monai/networks/__init__.py
+++ b/monai/networks/__init__.py
@@ -15,6 +15,7 @@
eval_mode,
get_state_dict,
icnr_init,
+ look_up_named_module,
normal_init,
normalize_transform,
one_hot,
@@ -23,6 +24,7 @@
replace_modules,
replace_modules_temp,
save_state,
+ set_named_module,
slice_channels,
to_norm_affine,
train_mode,
diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py
index 27feffea107..61d5fcf8f5d 100644
--- a/monai/networks/blocks/__init__.py
+++ b/monai/networks/blocks/__init__.py
@@ -15,9 +15,11 @@
from .backbone_fpn_utils import BackboneWithFPN
from .convolutions import Convolution, ResidualUnit
from .crf import CRF
+from .denseblock import ConvDenseBlock, DenseBlock
from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
from .downsample import MaxAvgPool
from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding
+from .encoder import BaseEncoder
from .fcn import FCN, GCN, MCFCN, Refine
from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
diff --git a/monai/networks/blocks/backbone_fpn_utils.py b/monai/networks/blocks/backbone_fpn_utils.py
index c663485583e..145a4ac2e13 100644
--- a/monai/networks/blocks/backbone_fpn_utils.py
+++ b/monai/networks/blocks/backbone_fpn_utils.py
@@ -43,7 +43,6 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""
This script is modified from from torchvision to support N-D images,
by overriding the definition of convolutional layers and pooling layers.
diff --git a/monai/networks/blocks/convolutions.py b/monai/networks/blocks/convolutions.py
index 17d3c9d0849..55735e2f582 100644
--- a/monai/networks/blocks/convolutions.py
+++ b/monai/networks/blocks/convolutions.py
@@ -18,7 +18,6 @@
from monai.networks.blocks import ADN
from monai.networks.layers.convutils import same_padding, stride_minus_kernel_padding
from monai.networks.layers.factories import Conv
-from monai.utils.deprecate_utils import deprecated_arg
class Convolution(nn.Sequential):
@@ -38,7 +37,7 @@ class Convolution(nn.Sequential):
from monai.networks.blocks import Convolution
conv = Convolution(
- dimensions=3,
+ spatial_dims=3,
in_channels=1,
out_channels=1,
adn_ordering="ADN",
@@ -76,7 +75,7 @@ class Convolution(nn.Sequential):
- When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).
- When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).
- The value of dropout_dim should be no no larger than the value of `spatial_dims`.
+ The value of dropout_dim should be no larger than the value of `spatial_dims`.
dilation: dilation rate. Defaults to 1.
groups: controls the connections between inputs and outputs. Defaults to 1.
bias: whether to have a bias term. Defaults to True.
@@ -87,9 +86,6 @@ class Convolution(nn.Sequential):
output_padding: controls the additional size added to one side of the output shape.
Defaults to None.
- .. deprecated:: 0.6.0
- ``dimensions`` is deprecated, use ``spatial_dims`` instead.
-
See also:
:py:class:`monai.networks.layers.Conv`
@@ -97,9 +93,6 @@ class Convolution(nn.Sequential):
"""
- @deprecated_arg(
- name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
- )
def __init__(
self,
spatial_dims: int,
@@ -119,16 +112,15 @@ def __init__(
is_transposed: bool = False,
padding: Optional[Union[Sequence[int], int]] = None,
output_padding: Optional[Union[Sequence[int], int]] = None,
- dimensions: Optional[int] = None,
) -> None:
super().__init__()
- self.dimensions = spatial_dims if dimensions is None else dimensions
+ self.spatial_dims = spatial_dims
self.in_channels = in_channels
self.out_channels = out_channels
self.is_transposed = is_transposed
if padding is None:
padding = same_padding(kernel_size, dilation)
- conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, self.dimensions]
+ conv_type = Conv[Conv.CONVTRANS if is_transposed else Conv.CONV, self.spatial_dims]
conv: nn.Module
if is_transposed:
@@ -170,7 +162,7 @@ def __init__(
in_channels=out_channels,
act=act,
norm=norm,
- norm_dim=self.dimensions,
+ norm_dim=self.spatial_dims,
dropout=dropout,
dropout_dim=dropout_dim,
),
@@ -237,7 +229,7 @@ class ResidualUnit(nn.Module):
- When dropout_dim = 2, Randomly zero out entire channels (a channel is a 2D feature map).
- When dropout_dim = 3, Randomly zero out entire channels (a channel is a 3D feature map).
- The value of dropout_dim should be no no larger than the value of `dimensions`.
+ The value of dropout_dim should be no larger than the value of `dimensions`.
dilation: dilation rate. Defaults to 1.
bias: whether to have a bias term. Defaults to True.
last_conv_only: for the last subunit, whether to use the convolutional layer only.
@@ -245,16 +237,12 @@ class ResidualUnit(nn.Module):
padding: controls the amount of implicit zero-paddings on both sides for padding number of points
for each dimension. Defaults to None.
- .. deprecated:: 0.6.0
- ``dimensions`` is deprecated, use ``spatial_dims`` instead.
-
See also:
:py:class:`monai.networks.blocks.Convolution`
"""
- @deprecated_arg(name="dimensions", since="0.6", msg_suffix="Please use `spatial_dims` instead.")
def __init__(
self,
spatial_dims: int,
@@ -272,10 +260,9 @@ def __init__(
bias: bool = True,
last_conv_only: bool = False,
padding: Optional[Union[Sequence[int], int]] = None,
- dimensions: Optional[int] = None,
) -> None:
super().__init__()
- self.dimensions = spatial_dims if dimensions is None else dimensions
+ self.spatial_dims = spatial_dims
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Sequential()
@@ -289,7 +276,7 @@ def __init__(
for su in range(subunits):
conv_only = last_conv_only and su == (subunits - 1)
unit = Convolution(
- self.dimensions,
+ self.spatial_dims,
schannels,
out_channels,
strides=sstrides,
@@ -320,7 +307,7 @@ def __init__(
rkernel_size = 1
rpadding = 0
- conv_type = Conv[Conv.CONV, self.dimensions]
+ conv_type = Conv[Conv.CONV, self.spatial_dims]
self.residual = conv_type(in_channels, out_channels, rkernel_size, strides, rpadding, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/monai/networks/blocks/denseblock.py b/monai/networks/blocks/denseblock.py
new file mode 100644
index 00000000000..dafd8d03a61
--- /dev/null
+++ b/monai/networks/blocks/denseblock.py
@@ -0,0 +1,130 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks import Convolution, ResidualUnit
+from monai.networks.layers.factories import Act, Norm
+
+__ALL__ = ["DenseBlock", "ConvDenseBlock"]
+
+
+class DenseBlock(nn.Sequential):
+ """
+ A DenseBlock is a sequence of layers where each layer's outputs are concatenated with their inputs. This has the
+ effect of accumulating outputs from previous layers as inputs to later ones and as the final output of the block.
+
+ Args:
+ layers: sequence of nn.Module objects to define the individual layers of the dense block
+ """
+
+ def __init__(self, layers: Sequence[nn.Module]):
+ super().__init__()
+ for i, l in enumerate(layers):
+ self.add_module(f"layers{i}", l)
+
+ def forward(self, x):
+ for l in self.children():
+ result = l(x)
+ x = torch.cat([x, result], 1)
+
+ return x
+
+
+class ConvDenseBlock(DenseBlock):
+ """
+ This dense block is defined as a sequence of `Convolution` or `ResidualUnit` blocks. The `_get_layer` method returns
+ an object for each layer and can be overridden to change the composition of the block.
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels.
+ channels: output channels for each layer.
+ dilations: dilation value for each layer.
+ kernel_size: convolution kernel size. Defaults to 3.
+ num_res_units: number of convolutions. Defaults to 2.
+ adn_ordering: a string representing the ordering of activation, normalization, and dropout. Defaults to "NDA".
+ act: activation type and arguments. Defaults to PReLU.
+ norm: feature normalization type and arguments. Defaults to instance norm.
+ dropout: dropout ratio. Defaults to no dropout.
+ bias: whether to have a bias term. Defaults to True.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ channels: Sequence[int],
+ dilations: Optional[Sequence[int]] = None,
+ kernel_size: Union[Sequence[int], int] = 3,
+ num_res_units: int = 0,
+ adn_ordering: str = "NDA",
+ act: Optional[Union[Tuple, str]] = Act.PRELU,
+ norm: Optional[Union[Tuple, str]] = Norm.INSTANCE,
+ dropout: Optional[int] = None,
+ bias: bool = True,
+ ):
+
+ self.spatial_dims = spatial_dims
+ self.kernel_size = kernel_size
+ self.num_res_units = num_res_units
+ self.adn_ordering = adn_ordering
+ self.act = act
+ self.norm = norm
+ self.dropout = dropout
+ self.bias = bias
+
+ l_channels = in_channels
+ dilations = dilations if dilations is not None else ([1] * len(channels))
+ layers = []
+
+ if len(channels) != len(dilations):
+ raise ValueError("Length of `channels` and `dilations` must match")
+
+ for c, d in zip(channels, dilations):
+ layer = self._get_layer(l_channels, c, d)
+ layers.append(layer)
+ l_channels += c
+
+ super().__init__(layers)
+
+ def _get_layer(self, in_channels, out_channels, dilation):
+ if self.num_res_units > 0:
+ return ResidualUnit(
+ spatial_dims=self.spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=self.kernel_size,
+ subunits=self.num_res_units,
+ adn_ordering=self.adn_ordering,
+ act=self.act,
+ norm=self.norm,
+ dropout=self.dropout,
+ dilation=dilation,
+ bias=self.bias,
+ )
+ else:
+ return Convolution(
+ spatial_dims=self.spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=self.kernel_size,
+ act=self.act,
+ norm=self.norm,
+ dropout=self.dropout,
+ dilation=dilation,
+ bias=self.bias,
+ )
diff --git a/monai/networks/blocks/dints_block.py b/monai/networks/blocks/dints_block.py
index b7365f50e33..1823845adf5 100644
--- a/monai/networks/blocks/dints_block.py
+++ b/monai/networks/blocks/dints_block.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Tuple, Union
import torch
diff --git a/monai/networks/blocks/encoder.py b/monai/networks/blocks/encoder.py
new file mode 100644
index 00000000000..b19317a2af9
--- /dev/null
+++ b/monai/networks/blocks/encoder.py
@@ -0,0 +1,84 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABCMeta, abstractmethod
+from typing import Dict, List, Tuple
+
+__all__ = ["BaseEncoder"]
+
+
+class BaseEncoder(metaclass=ABCMeta):
+ """
+ Abstract class defines interface of encoders in flexible unet.
+ Encoders in flexible unet must derive from this class. Each interface method
+ should return a list containing relative information about a series of newtworks
+ defined by encoder. For example, the efficient-net encoder implement 10 basic
+ network structures in one encoder. When calling `get_encoder_name_string_list`
+ function, a string list like ["efficientnet-b0", "efficientnet-b1" ... "efficientnet-l2"]
+ should be returned.
+ """
+
+ @classmethod
+ @abstractmethod
+ def get_encoder_parameters(cls) -> List[Dict]:
+ """
+ Get parameter list to initialize encoder networks.
+ Each parameter dict must have `spatial_dims`, `in_channels`
+ and `pretrained` parameters.
+ The reason that this function should return a list is that a
+ series of encoders can be implemented by one encoder class
+ given different initialization parameters. Each parameter dict
+ in return list should be able to initialize a unique encoder.
+ """
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def num_channels_per_output(cls) -> List[Tuple[int, ...]]:
+ """
+ Get number of output features' channels.
+ The reason that this function should return a list is that a
+ series of encoders can be implemented by one encoder class
+ given different initialization parameters. And it is possible
+ that different encoders have different output feature map
+ channels. Therefore a list of output feature map channel tuples
+ corresponding to each encoder should be returned by this method.
+ """
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def num_outputs(cls) -> List[int]:
+ """
+ Get number of outputs of encoder.
+ The reason that this function should return a list is that a
+ series of encoders can be implemented by one encoder class
+ given different initialization parameters. And it is possible
+ that different encoders have different output feature numbers.
+ Therefore a list of output feature numbers corresponding to
+ each encoder should be returned by this method.
+ """
+ raise NotImplementedError
+
+ @classmethod
+ @abstractmethod
+ def get_encoder_names(cls) -> List[str]:
+ """
+ Get the name string of encoders which will be used to initialize
+ flexible unet.
+ The reason that this function should return a list is that a
+ series of encoders can be implemented by one encoder class
+ given different initialization parameters. And a name string is
+ the key to each encoder in flexible unet backbone registry.
+ Therefore this method should return every encoder name that needs
+ to be registed in flexible unet.
+ """
+ raise NotImplementedError
diff --git a/monai/networks/blocks/feature_pyramid_network.py b/monai/networks/blocks/feature_pyramid_network.py
index 2373cfc0994..f9503212978 100644
--- a/monai/networks/blocks/feature_pyramid_network.py
+++ b/monai/networks/blocks/feature_pyramid_network.py
@@ -43,7 +43,6 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""
This script is modified from from torchvision to support N-D images,
by overriding the definition of convolutional layers and pooling layers.
@@ -258,6 +257,6 @@ def forward(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
results, names = self.extra_blocks(results, x_values, names)
# make it back an OrderedDict
- out = OrderedDict([(k, v) for k, v in zip(names, results)])
+ out = OrderedDict(list(zip(names, results)))
return out
diff --git a/monai/networks/blocks/fft_utils_t.py b/monai/networks/blocks/fft_utils_t.py
index 0d6b99d7e17..1283f05c6b7 100644
--- a/monai/networks/blocks/fft_utils_t.py
+++ b/monai/networks/blocks/fft_utils_t.py
@@ -9,13 +9,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Sequence
+from typing import List
import torch
from torch import Tensor
-from monai.utils.type_conversion import convert_data_type
-
def roll_1d(x: Tensor, shift: int, shift_dim: int) -> Tensor:
"""
@@ -44,7 +42,7 @@ def roll_1d(x: Tensor, shift: int, shift_dim: int) -> Tensor:
return torch.cat((right, left), dim=shift_dim)
-def roll(x: Tensor, shift: Sequence[int], shift_dims: Sequence[int]) -> Tensor:
+def roll(x: Tensor, shift: List[int], shift_dims: List[int]) -> Tensor:
"""
Similar to np.roll but applies to PyTorch Tensors
@@ -68,7 +66,7 @@ def roll(x: Tensor, shift: Sequence[int], shift_dims: Sequence[int]) -> Tensor:
return x
-def fftshift(x: Tensor, shift_dims: Optional[Sequence[int]] = None) -> Tensor:
+def fftshift(x: Tensor, shift_dims: List[int]) -> Tensor:
"""
Similar to np.fft.fftshift but applies to PyTorch Tensors
@@ -84,18 +82,13 @@ def fftshift(x: Tensor, shift_dims: Optional[Sequence[int]] = None) -> Tensor:
Note:
This function is called when fftshift is not available in the running pytorch version
"""
- if shift_dims is None:
- # for torch.jit.script based on the fastmri repository
- shift_dims = [0] * (x.dim())
- for i in range(1, x.dim()):
- shift_dims[i] = i
shift = [0] * len(shift_dims)
for i, dim_num in enumerate(shift_dims):
shift[i] = x.shape[dim_num] // 2
return roll(x, shift, shift_dims)
-def ifftshift(x: Tensor, shift_dims: Optional[Sequence[int]] = None) -> Tensor:
+def ifftshift(x: Tensor, shift_dims: List[int]) -> Tensor:
"""
Similar to np.fft.ifftshift but applies to PyTorch Tensors
@@ -111,11 +104,6 @@ def ifftshift(x: Tensor, shift_dims: Optional[Sequence[int]] = None) -> Tensor:
Note:
This function is called when ifftshift is not available in the running pytorch version
"""
- if shift_dims is None:
- # for torch.jit.script based on the fastmri repository
- shift_dims = [0] * (x.dim())
- for i in range(1, x.dim()):
- shift_dims[i] = i
shift = [0] * len(shift_dims)
for i, dim_num in enumerate(shift_dims):
shift[i] = (x.shape[dim_num] + 1) // 2
@@ -151,28 +139,21 @@ def ifftn_centered_t(ksp: Tensor, spatial_dims: int, is_complex: bool = True) ->
output2 = ifftn_centered(ksp, spatial_dims=2, is_complex=True)
"""
# define spatial dims to perform ifftshift, fftshift, and ifft
- shift = tuple(range(-spatial_dims, 0))
+ shift = list(range(-spatial_dims, 0))
if is_complex:
if ksp.shape[-1] != 2:
raise ValueError(f"ksp.shape[-1] is not 2 ({ksp.shape[-1]}).")
- shift = tuple(range(-spatial_dims - 1, -1))
- dims = tuple(range(-spatial_dims, 0))
+ shift = list(range(-spatial_dims - 1, -1))
+ dims = list(range(-spatial_dims, 0))
- # apply ifft
- if hasattr(torch.fft, "ifftshift"): # ifftshift was added in pytorch 1.8
- x = torch.fft.ifftshift(ksp, dim=shift)
- else:
- x = ifftshift(ksp, shift)
+ x = ifftshift(ksp, shift)
if is_complex:
x = torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x), dim=dims, norm="ortho"))
else:
x = torch.view_as_real(torch.fft.ifftn(x, dim=dims, norm="ortho"))
- if hasattr(torch.fft, "fftshift"):
- out = convert_data_type(torch.fft.fftshift(x, dim=shift), torch.Tensor)[0]
- else:
- out = convert_data_type(fftshift(x, shift), torch.Tensor)[0]
+ out: Tensor = fftshift(x, shift)
return out
@@ -206,27 +187,20 @@ def fftn_centered_t(im: Tensor, spatial_dims: int, is_complex: bool = True) -> T
output2 = fftn_centered(im, spatial_dims=2, is_complex=True)
"""
# define spatial dims to perform ifftshift, fftshift, and fft
- shift = tuple(range(-spatial_dims, 0))
+ shift = list(range(-spatial_dims, 0))
if is_complex:
if im.shape[-1] != 2:
raise ValueError(f"img.shape[-1] is not 2 ({im.shape[-1]}).")
- shift = tuple(range(-spatial_dims - 1, -1))
- dims = tuple(range(-spatial_dims, 0))
+ shift = list(range(-spatial_dims - 1, -1))
+ dims = list(range(-spatial_dims, 0))
- # apply fft
- if hasattr(torch.fft, "ifftshift"): # ifftshift was added in pytorch 1.8
- x = torch.fft.ifftshift(im, dim=shift)
- else:
- x = ifftshift(im, shift)
+ x = ifftshift(im, shift)
if is_complex:
x = torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x), dim=dims, norm="ortho"))
else:
x = torch.view_as_real(torch.fft.fftn(x, dim=dims, norm="ortho"))
- if hasattr(torch.fft, "fftshift"):
- out = convert_data_type(torch.fft.fftshift(x, dim=shift), torch.Tensor)[0]
- else:
- out = convert_data_type(fftshift(x, shift), torch.Tensor)[0]
+ out: Tensor = fftshift(x, shift)
return out
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index db92111d142..d0b87fda6b9 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -23,12 +23,13 @@ class SABlock(nn.Module):
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
"""
- def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0) -> None:
+ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
dropout_rate: faction of the input units to drop.
+ qkv_bias: bias term for the qkv linear layer.
"""
@@ -42,7 +43,7 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0)
self.num_heads = num_heads
self.out_proj = nn.Linear(hidden_size, hidden_size)
- self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
+ self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
self.out_rearrange = Rearrange("b h l d -> b l (h d)")
self.drop_output = nn.Dropout(dropout_rate)
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 616d84e067f..88b33acb09b 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -21,13 +21,16 @@ class TransformerBlock(nn.Module):
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale "
"""
- def __init__(self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0) -> None:
+ def __init__(
+ self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate: float = 0.0, qkv_bias: bool = False
+ ) -> None:
"""
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer.
num_heads: number of attention heads.
dropout_rate: faction of the input units to drop.
+ qkv_bias: apply bias term for the qkv linear layer
"""
@@ -41,7 +44,7 @@ def __init__(self, hidden_size: int, mlp_dim: int, num_heads: int, dropout_rate:
self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
self.norm1 = nn.LayerNorm(hidden_size)
- self.attn = SABlock(hidden_size, num_heads, dropout_rate)
+ self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias)
self.norm2 = nn.LayerNorm(hidden_size)
def forward(self, x):
diff --git a/monai/networks/blocks/unetr_block.py b/monai/networks/blocks/unetr_block.py
index a9d871a644f..452a535a2a8 100644
--- a/monai/networks/blocks/unetr_block.py
+++ b/monai/networks/blocks/unetr_block.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Sequence, Tuple, Union
import torch
diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py
index 364db0e2363..ee03aa4e674 100644
--- a/monai/networks/blocks/upsample.py
+++ b/monai/networks/blocks/upsample.py
@@ -16,7 +16,7 @@
from monai.networks.layers.factories import Conv, Pad, Pool
from monai.networks.utils import icnr_init, pixelshuffle
-from monai.utils import InterpolateMode, UpsampleMode, deprecated_arg, ensure_tuple_rep, look_up_option
+from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
__all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]
@@ -27,6 +27,7 @@ class UpSample(nn.Sequential):
Supported modes are:
- "deconv": uses a transposed convolution.
+ - "deconvgroup": uses a transposed group convolution.
- "nontrainable": uses :py:class:`torch.nn.Upsample`.
- "pixelshuffle": uses :py:class:`monai.networks.blocks.SubpixelUpsample`.
@@ -34,15 +35,13 @@ class UpSample(nn.Sequential):
(often used to map the number of features from `in_channels` to `out_channels`).
"""
- @deprecated_arg(
- name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
- )
def __init__(
self,
spatial_dims: int,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
scale_factor: Union[Sequence[float], float] = 2,
+ kernel_size: Optional[Union[Sequence[float], float]] = None,
size: Optional[Union[Tuple[int], int]] = None,
mode: Union[UpsampleMode, str] = UpsampleMode.DECONV,
pre_conv: Optional[Union[nn.Module, str]] = "default",
@@ -50,7 +49,6 @@ def __init__(
align_corners: Optional[bool] = True,
bias: bool = True,
apply_pad_pool: bool = True,
- dimensions: Optional[int] = None,
) -> None:
"""
Args:
@@ -58,12 +56,13 @@ def __init__(
in_channels: number of channels of the input image.
out_channels: number of channels of the output image. Defaults to `in_channels`.
scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2.
+ kernel_size: kernel size used during transposed convolutions. Defaults to `scale_factor`.
size: spatial size of the output image.
Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
In torch.nn.functional.interpolate, only one of `size` or `scale_factor` should be defined,
thus if size is defined, `scale_factor` will not be used.
Defaults to None.
- mode: {``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``.
+ mode: {``"deconv"``, ``"deconvgroup"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``.
pre_conv: a conv block applied before upsampling. Defaults to "default".
When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized when
Only used in the "nontrainable" or "pixelshuffle" mode.
@@ -80,14 +79,19 @@ def __init__(
size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`.
Only used in the "pixelshuffle" mode.
- .. deprecated:: 0.6.0
- ``dimensions`` is deprecated, use ``spatial_dims`` instead.
"""
super().__init__()
- if dimensions is not None:
- spatial_dims = dimensions
scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims)
up_mode = look_up_option(mode, UpsampleMode)
+
+ if not kernel_size:
+ kernel_size_ = scale_factor_
+ output_padding = padding = 0
+ else:
+ kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)
+ padding = tuple((k - 1) // 2 for k in kernel_size_) # type: ignore
+ output_padding = tuple(s - 1 - (k - 1) % 2 for k, s in zip(kernel_size_, scale_factor_)) # type: ignore
+
if up_mode == UpsampleMode.DECONV:
if not in_channels:
raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
@@ -96,8 +100,31 @@ def __init__(
Conv[Conv.CONVTRANS, spatial_dims](
in_channels=in_channels,
out_channels=out_channels or in_channels,
- kernel_size=scale_factor_,
+ kernel_size=kernel_size_,
+ stride=scale_factor_,
+ padding=padding,
+ output_padding=output_padding,
+ bias=bias,
+ ),
+ )
+ elif up_mode == UpsampleMode.DECONVGROUP:
+ if not in_channels:
+ raise ValueError(f"in_channels needs to be specified in the '{mode}' mode.")
+
+ if out_channels is None:
+ out_channels = in_channels
+ groups = out_channels if in_channels % out_channels == 0 else 1
+
+ self.add_module(
+ "deconvgroup",
+ Conv[Conv.CONVTRANS, spatial_dims](
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size_,
stride=scale_factor_,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
bias=bias,
),
)
@@ -173,9 +200,6 @@ class SubpixelUpsample(nn.Module):
"""
- @deprecated_arg(
- name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
- )
def __init__(
self,
spatial_dims: int,
@@ -185,7 +209,6 @@ def __init__(
conv_block: Optional[Union[nn.Module, str]] = "default",
apply_pad_pool: bool = True,
bias: bool = True,
- dimensions: Optional[int] = None,
) -> None:
"""
Args:
@@ -204,15 +227,13 @@ def __init__(
component of subpixel convolutions described in Aitken et al.
bias: whether to have a bias term in the default conv_block. Defaults to True.
- .. deprecated:: 0.6.0
- ``dimensions`` is deprecated, use ``spatial_dims`` instead.
"""
super().__init__()
if scale_factor <= 0:
raise ValueError(f"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.")
- self.dimensions = spatial_dims if dimensions is None else dimensions
+ self.dimensions = spatial_dims
self.scale_factor = scale_factor
if conv_block == "default":
diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py
index 5b925258b62..7a28e863014 100644
--- a/monai/networks/blocks/warp.py
+++ b/monai/networks/blocks/warp.py
@@ -82,13 +82,21 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa
else:
self._padding_mode = GridSamplePadMode(padding_mode).value
- @staticmethod
- def get_reference_grid(ddf: torch.Tensor) -> torch.Tensor:
+ self.ref_grid = None
+
+ def get_reference_grid(self, ddf: torch.Tensor) -> torch.Tensor:
+ if (
+ self.ref_grid is not None
+ and self.ref_grid.shape[0] == ddf.shape[0]
+ and self.ref_grid.shape[1:] == ddf.shape[2:]
+ ):
+ return self.ref_grid # type: ignore
mesh_points = [torch.arange(0, dim) for dim in ddf.shape[2:]]
grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...)
grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...)
- grid = grid.to(ddf)
- return grid
+ self.ref_grid = grid.to(ddf)
+ self.ref_grid.requires_grad = False
+ return self.ref_grid
def forward(self, image: torch.Tensor, ddf: torch.Tensor):
"""
@@ -105,7 +113,8 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):
ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:])
if ddf.shape != ddf_shape:
raise ValueError(
- f"Given input {spatial_dims}-d image shape {image.shape}, " f"the input DDF shape must be {ddf_shape}."
+ f"Given input {spatial_dims}-d image shape {image.shape}, the input DDF shape must be {ddf_shape}, "
+ f"Got {ddf.shape} instead."
)
grid = self.get_reference_grid(ddf) + ddf
grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims)
diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py
index f122dccee66..31bd36dd8fb 100644
--- a/monai/networks/layers/__init__.py
+++ b/monai/networks/layers/__init__.py
@@ -20,10 +20,12 @@
Flatten,
GaussianFilter,
HilbertTransform,
+ MedianFilter,
Reshape,
SavitzkyGolayFilter,
SkipConnection,
apply_filter,
+ median_filter,
separable_filtering,
)
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
diff --git a/monai/networks/layers/convutils.py b/monai/networks/layers/convutils.py
index 1e9ce954e8d..fe688b24fff 100644
--- a/monai/networks/layers/convutils.py
+++ b/monai/networks/layers/convutils.py
@@ -74,7 +74,7 @@ def calculate_out_shape(
out_shape_np = ((in_shape_np - kernel_size_np + padding_np + padding_np) // stride_np) + 1
out_shape = tuple(int(s) for s in out_shape_np)
- return out_shape if len(out_shape) > 1 else out_shape[0]
+ return out_shape
def gaussian_1d(
diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py
index 89fe1912a51..a58dbc161f2 100644
--- a/monai/networks/layers/factories.py
+++ b/monai/networks/layers/factories.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
Defines factories for creating layers in generic, extensible, and dimensionally independent ways. A separate factory
object is created for each type of layer, and factory functions keyed to names are added to these objects. Whenever
@@ -70,7 +69,6 @@ def use_factory(fact_args):
InstanceNorm3dNVFuser, has_nvfuser = optional_import("apex.normalization", name="InstanceNorm3dNVFuser")
-
__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
@@ -250,7 +248,7 @@ def sync_batch_factory(_dim) -> Type[nn.SyncBatchNorm]:
@Norm.factory_function("instance_nvfuser")
def instance_nvfuser_factory(dim):
"""
- `InstanceNorm3dNVFuser` is a faster verison of InstanceNorm layer and implemented in `apex`.
+ `InstanceNorm3dNVFuser` is a faster version of InstanceNorm layer and implemented in `apex`.
It only supports 3d tensors as the input. It also requires to use with CUDA and non-Windows OS.
In this function, if the required library `apex.normalization.InstanceNorm3dNVFuser` does not exist,
`nn.InstanceNorm3d` will be returned instead.
diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py
index 3de4e75766d..6b8bc65e16c 100644
--- a/monai/networks/layers/simplelayers.py
+++ b/monai/networks/layers/simplelayers.py
@@ -11,7 +11,7 @@
import math
from copy import deepcopy
-from typing import List, Sequence, Union
+from typing import List, Optional, Sequence, Union
import torch
import torch.nn.functional as F
@@ -20,8 +20,16 @@
from monai.networks.layers.convutils import gaussian_1d
from monai.networks.layers.factories import Conv
-from monai.utils import ChannelMatching, SkipMode, look_up_option, optional_import, pytorch_after
-from monai.utils.misc import issequenceiterable
+from monai.utils import (
+ ChannelMatching,
+ SkipMode,
+ convert_to_tensor,
+ ensure_tuple_rep,
+ issequenceiterable,
+ look_up_option,
+ optional_import,
+ pytorch_after,
+)
_C, _ = optional_import("monai._C")
fft, _ = optional_import("torch.fft")
@@ -32,10 +40,12 @@
"GaussianFilter",
"HilbertTransform",
"LLTM",
+ "MedianFilter",
"Reshape",
"SavitzkyGolayFilter",
"SkipConnection",
"apply_filter",
+ "median_filter",
"separable_filtering",
]
@@ -168,7 +178,6 @@ def _separable_filtering_conv(
paddings: List[int],
num_channels: int,
) -> torch.Tensor:
-
if d < 0:
return input_
@@ -290,6 +299,9 @@ def apply_filter(x: torch.Tensor, kernel: torch.Tensor, **kwargs) -> torch.Tenso
else:
# even-sized kernels are not supported
kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
+ elif kwargs["padding"] == "same" and not pytorch_after(1, 10):
+ # even-sized kernels are not supported
+ kwargs["padding"] = [(k - 1) // 2 for k in kernel.shape[2:]]
if "stride" not in kwargs:
kwargs["stride"] = 1
@@ -363,7 +375,11 @@ def _make_coeffs(window_length, order):
a = idx ** torch.arange(order + 1, dtype=torch.float, device="cpu").reshape(-1, 1)
y = torch.zeros(order + 1, dtype=torch.float, device="cpu")
y[0] = 1.0
- return torch.lstsq(y, a).solution.squeeze()
+ return (
+ torch.lstsq(y, a).solution.squeeze()
+ if not pytorch_after(1, 11)
+ else torch.linalg.lstsq(a, y).solution.squeeze()
+ )
class HilbertTransform(nn.Module):
@@ -427,6 +443,118 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.as_tensor(ht, device=ht.device, dtype=ht.dtype)
+def get_binary_kernel(window_size: Sequence[int], dtype=torch.float, device=None) -> torch.Tensor:
+ """
+ Create a binary kernel to extract the patches.
+ The window size HxWxD will create a (H*W*D)xHxWxD kernel.
+ """
+ win_size = convert_to_tensor(window_size, int, wrap_sequence=True)
+ prod = torch.prod(win_size)
+ s = [prod, 1, *win_size]
+ return torch.diag(torch.ones(prod, dtype=dtype, device=device)).view(s) # type: ignore
+
+
+def median_filter(
+ in_tensor: torch.Tensor,
+ kernel_size: Sequence[int] = (3, 3, 3),
+ spatial_dims: int = 3,
+ kernel: Optional[torch.Tensor] = None,
+ **kwargs,
+) -> torch.Tensor:
+ """
+ Apply median filter to an image.
+
+ Args:
+ in_tensor: input tensor; median filtering will be applied to the last `spatial_dims` dimensions.
+ kernel_size: the convolution kernel size.
+ spatial_dims: number of spatial dimensions to apply median filtering.
+ kernel: an optional customized kernel.
+ kwargs: additional parameters to the `conv`.
+
+ Returns:
+ the filtered input tensor, shape remains the same as ``in_tensor``
+
+ Example::
+
+ >>> from monai.networks.layers import median_filter
+ >>> import torch
+ >>> x = torch.rand(4, 5, 7, 6)
+ >>> output = median_filter(x, (3, 3, 3))
+ >>> output.shape
+ torch.Size([4, 5, 7, 6])
+
+ """
+ if not isinstance(in_tensor, torch.Tensor):
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(in_tensor)}")
+
+ original_shape = in_tensor.shape
+ oshape, sshape = original_shape[: len(original_shape) - spatial_dims], original_shape[-spatial_dims:]
+ oprod = torch.prod(convert_to_tensor(oshape, int, wrap_sequence=True))
+ # prepare kernel
+ if kernel is None:
+ kernel_size = ensure_tuple_rep(kernel_size, spatial_dims)
+ kernel = get_binary_kernel(kernel_size, in_tensor.dtype, in_tensor.device)
+ else:
+ kernel = kernel.to(in_tensor)
+ # map the local window to single vector
+ conv = [F.conv1d, F.conv2d, F.conv3d][spatial_dims - 1]
+ reshaped_input: torch.Tensor = in_tensor.reshape(oprod, 1, *sshape) # type: ignore
+
+ # even-sized kernels are not supported
+ padding = [(k - 1) // 2 for k in reversed(kernel.shape[2:]) for _ in range(2)]
+ padded_input: torch.Tensor = F.pad(reshaped_input, pad=padding, mode="replicate")
+ features: torch.Tensor = conv(padded_input, kernel, padding=0, stride=1, **kwargs)
+
+ features = features.view(oprod, -1, *sshape) # type: ignore
+
+ # compute the median along the feature axis
+ median: torch.Tensor = torch.median(features, dim=1)[0]
+ median = median.reshape(original_shape)
+
+ return median
+
+
+class MedianFilter(nn.Module):
+ """
+ Apply median filter to an image.
+
+ Args:
+ radius: the blurring kernel radius (radius of 1 corresponds to 3x3x3 kernel when spatial_dims=3).
+
+ Returns:
+ filtered input tensor.
+
+ Example::
+
+ >>> from monai.networks.layers import MedianFilter
+ >>> import torch
+ >>> in_tensor = torch.rand(4, 5, 7, 6)
+ >>> blur = MedianFilter([1, 1, 1]) # 3x3x3 kernel
+ >>> output = blur(in_tensor)
+ >>> output.shape
+ torch.Size([4, 5, 7, 6])
+
+ """
+
+ def __init__(self, radius: Union[Sequence[int], int], spatial_dims: int = 3, device="cpu") -> None:
+ super().__init__()
+ self.spatial_dims = spatial_dims
+ self.radius: Sequence[int] = ensure_tuple_rep(radius, spatial_dims)
+ self.window: Sequence[int] = [1 + 2 * deepcopy(r) for r in self.radius]
+ self.kernel = get_binary_kernel(self.window, device=device)
+
+ def forward(self, in_tensor: torch.Tensor, number_of_passes=1) -> torch.Tensor:
+ """
+ Args:
+ in_tensor: input tensor, median filtering will be applied to the last `spatial_dims` dimensions.
+ number_of_passes: median filtering will be repeated this many times
+ """
+ x = in_tensor
+ for _ in range(number_of_passes):
+ x = median_filter(x, kernel=self.kernel, spatial_dims=self.spatial_dims)
+ return x
+
+
class GaussianFilter(nn.Module):
def __init__(
self,
diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py
index 42fac587164..a630a5edc70 100644
--- a/monai/networks/layers/utils.py
+++ b/monai/networks/layers/utils.py
@@ -11,6 +11,8 @@
from typing import Optional, Tuple, Union
+import torch.nn
+
from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args
from monai.utils import has_option
@@ -36,6 +38,8 @@ def get_norm_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1, cha
channels: number of features/channels when the normalization layer requires this parameter
but it is not specified in the norm parameters.
"""
+ if name == "":
+ return torch.nn.Identity()
norm_name, norm_args = split_args(name)
norm_type = Norm[norm_name, spatial_dims]
kw_args = dict(norm_args)
@@ -62,6 +66,8 @@ def get_act_layer(name: Union[Tuple, str]):
Args:
name: an activation type string or a tuple of type string and parameters.
"""
+ if name == "":
+ return torch.nn.Identity()
act_name, act_args = split_args(name)
act_type = Act[act_name]
return act_type(**act_args)
@@ -84,6 +90,8 @@ def get_dropout_layer(name: Union[Tuple, str, float, int], dropout_dim: Optional
name: a dropout ratio or a tuple of dropout type and parameters.
dropout_dim: the spatial dimension of the dropout operation.
"""
+ if name == "":
+ return torch.nn.Identity()
if isinstance(name, (int, float)):
# if dropout was specified simply as a p value, use default name and make a keyword map with the value
drop_name = Dropout.DROPOUT
@@ -111,6 +119,8 @@ def get_pool_layer(name: Union[Tuple, str], spatial_dims: Optional[int] = 1):
spatial_dims: number of spatial dimensions of the input.
"""
+ if name == "":
+ return torch.nn.Identity()
pool_name, pool_args = split_args(name)
pool_type = Pool[pool_name, spatial_dims]
return pool_type(**pool_args)
diff --git a/monai/networks/layers/weight_init.py b/monai/networks/layers/weight_init.py
index 9b81ef17f87..b0c6fae2c20 100644
--- a/monai/networks/layers/weight_init.py
+++ b/monai/networks/layers/weight_init.py
@@ -55,7 +55,7 @@ def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
b: the maximum cutoff value
"""
- if not std > 0:
+ if std <= 0:
raise ValueError("the standard deviation should be greater than zero.")
if a >= b:
diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index a4e8312b305..18a85d802a0 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -13,6 +13,7 @@
from .attentionunet import AttentionUnet
from .autoencoder import AutoEncoder
from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
+from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus
from .classifier import Classifier, Critic, Discriminator
from .densenet import (
DenseNet,
@@ -37,9 +38,11 @@
EfficientNet,
EfficientNetBN,
EfficientNetBNFeatures,
+ EfficientNetEncoder,
drop_connect,
get_efficientnet_image_size,
)
+from .flexible_unet import FLEXUNET_BACKBONE, FlexibleUNet, FlexUNet, FlexUNetEncoderRegister
from .fullyconnectednet import FullyConnectedNet, VarFullyConnectedNet
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
@@ -48,8 +51,20 @@
from .netadapter import NetAdapter
from .regressor import Regressor
from .regunet import GlobalNet, LocalNet, RegUNet
-from .resnet import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
+from .resnet import (
+ ResNet,
+ ResNetBlock,
+ ResNetBottleneck,
+ resnet10,
+ resnet18,
+ resnet34,
+ resnet50,
+ resnet101,
+ resnet152,
+ resnet200,
+)
from .segresnet import SegResNet, SegResNetVAE
+from .segresnet_ds import SegResNetDS
from .senet import (
SENet,
SEnet,
diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py
index 177a54e105b..a57b57425ea 100644
--- a/monai/networks/nets/attentionunet.py
+++ b/monai/networks/nets/attentionunet.py
@@ -143,12 +143,27 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
class AttentionLayer(nn.Module):
- def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0):
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ submodule: nn.Module,
+ up_kernel_size=3,
+ strides=2,
+ dropout=0.0,
+ ):
super().__init__()
self.attention = AttentionBlock(
spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2
)
- self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2)
+ self.upconv = UpConv(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=in_channels,
+ strides=strides,
+ kernel_size=up_kernel_size,
+ )
self.merge = Convolution(
spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout
)
@@ -174,7 +189,7 @@ class AttentionUnet(nn.Module):
channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2.
strides (Sequence[int]): stride to use for convolutions.
kernel_size: convolution kernel size.
- upsample_kernel_size: convolution kernel size for transposed convolution layers.
+ up_kernel_size: convolution kernel size for transposed convolution layers.
dropout: dropout ratio. Defaults to no dropout.
"""
@@ -210,9 +225,9 @@ def __init__(
)
self.up_kernel_size = up_kernel_size
- def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module:
+ def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
if len(channels) > 2:
- subblock = _create_block(channels[1:], strides[1:], level=level + 1)
+ subblock = _create_block(channels[1:], strides[1:])
return AttentionLayer(
spatial_dims=spatial_dims,
in_channels=channels[0],
@@ -227,17 +242,19 @@ def _create_block(channels: Sequence[int], strides: Sequence[int], level: int =
),
subblock,
),
+ up_kernel_size=self.up_kernel_size,
+ strides=strides[0],
dropout=dropout,
)
else:
# the next layer is the bottom so stop recursion,
- # create the bottom layer as the sublock for this layer
- return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1)
+ # create the bottom layer as the subblock for this layer
+ return self._get_bottom_layer(channels[0], channels[1], strides[0])
encdec = _create_block(self.channels, self.strides)
self.model = nn.Sequential(head, encdec, reduce_channels)
- def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module:
+ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module:
return AttentionLayer(
spatial_dims=self.dimensions,
in_channels=in_channels,
@@ -249,6 +266,8 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, l
strides=strides,
dropout=self.dropout,
),
+ up_kernel_size=self.up_kernel_size,
+ strides=strides,
dropout=self.dropout,
)
diff --git a/monai/networks/nets/autoencoder.py b/monai/networks/nets/autoencoder.py
index 75edde70ebb..a88d1861ad6 100644
--- a/monai/networks/nets/autoencoder.py
+++ b/monai/networks/nets/autoencoder.py
@@ -16,7 +16,6 @@
from monai.networks.blocks import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
-from monai.utils import deprecated_arg
__all__ = ["AutoEncoder"]
@@ -57,9 +56,6 @@ class AutoEncoder(nn.Module):
According to `Performance Tuning Guide `_,
if a conv layer is directly followed by a batch norm layer, bias should be False.
- .. deprecated:: 0.6.0
- ``dimensions`` is deprecated, use ``spatial_dims`` instead.
-
Examples::
from monai.networks.nets import AutoEncoder
@@ -88,9 +84,6 @@ class AutoEncoder(nn.Module):
"""
- @deprecated_arg(
- name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
- )
def __init__(
self,
spatial_dims: int,
@@ -108,11 +101,10 @@ def __init__(
norm: Union[Tuple, str] = Norm.INSTANCE,
dropout: Optional[Union[Tuple, str, float]] = None,
bias: bool = True,
- dimensions: Optional[int] = None,
) -> None:
super().__init__()
- self.dimensions = spatial_dims if dimensions is None else dimensions
+ self.dimensions = spatial_dims
self.in_channels = in_channels
self.out_channels = out_channels
self.channels = list(channels)
@@ -239,6 +231,7 @@ def _get_encode_layer(self, in_channels: int, out_channels: int, strides: int, i
bias=self.bias,
last_conv_only=is_last,
)
+ return mod
mod = Convolution(
spatial_dims=self.dimensions,
in_channels=in_channels,
diff --git a/monai/networks/nets/basic_unet.py b/monai/networks/nets/basic_unet.py
index 1e468465767..6fe77038fe9 100644
--- a/monai/networks/nets/basic_unet.py
+++ b/monai/networks/nets/basic_unet.py
@@ -24,7 +24,6 @@
class TwoConv(nn.Sequential):
"""two convolutions."""
- @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.")
def __init__(
self,
spatial_dims: int,
@@ -34,7 +33,6 @@ def __init__(
norm: Union[str, tuple],
bias: bool,
dropout: Union[float, tuple] = 0.0,
- dim: Optional[int] = None,
):
"""
Args:
@@ -46,13 +44,9 @@ def __init__(
bias: whether to have a bias term in convolution blocks.
dropout: dropout ratio. Defaults to no dropout.
- .. deprecated:: 0.6.0
- ``dim`` is deprecated, use ``spatial_dims`` instead.
"""
super().__init__()
- if dim is not None:
- spatial_dims = dim
conv_0 = Convolution(spatial_dims, in_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1)
conv_1 = Convolution(
spatial_dims, out_chns, out_chns, act=act, norm=norm, dropout=dropout, bias=bias, padding=1
@@ -64,7 +58,6 @@ def __init__(
class Down(nn.Sequential):
"""maxpooling downsampling and two convolutions."""
- @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.")
def __init__(
self,
spatial_dims: int,
@@ -74,7 +67,6 @@ def __init__(
norm: Union[str, tuple],
bias: bool,
dropout: Union[float, tuple] = 0.0,
- dim: Optional[int] = None,
):
"""
Args:
@@ -86,12 +78,8 @@ def __init__(
bias: whether to have a bias term in convolution blocks.
dropout: dropout ratio. Defaults to no dropout.
- .. deprecated:: 0.6.0
- ``dim`` is deprecated, use ``spatial_dims`` instead.
"""
super().__init__()
- if dim is not None:
- spatial_dims = dim
max_pooling = Pool["MAX", spatial_dims](kernel_size=2)
convs = TwoConv(spatial_dims, in_chns, out_chns, act, norm, bias, dropout)
self.add_module("max_pooling", max_pooling)
@@ -101,7 +89,6 @@ def __init__(
class UpCat(nn.Module):
"""upsampling, concatenation with the encoder feature map, two convolutions"""
- @deprecated_arg(name="dim", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead.")
def __init__(
self,
spatial_dims: int,
@@ -117,13 +104,13 @@ def __init__(
interp_mode: str = "linear",
align_corners: Optional[bool] = True,
halves: bool = True,
- dim: Optional[int] = None,
+ is_pad: bool = True,
):
"""
Args:
spatial_dims: number of spatial dimensions.
in_chns: number of input channels to be upsampled.
- cat_chns: number of channels from the decoder.
+ cat_chns: number of channels from the encoder.
out_chns: number of output channels.
act: activation type and arguments.
norm: feature normalization type and arguments.
@@ -139,13 +126,10 @@ def __init__(
Only used in the "nontrainable" mode.
halves: whether to halve the number of channels during upsampling.
This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`.
+ is_pad: whether to pad upsampling features to fit features from encoder. Defaults to True.
- .. deprecated:: 0.6.0
- ``dim`` is deprecated, use ``spatial_dims`` instead.
"""
super().__init__()
- if dim is not None:
- spatial_dims = dim
if upsample == "nontrainable" and pre_conv is None:
up_chns = in_chns
else:
@@ -161,6 +145,7 @@ def __init__(
align_corners=align_corners,
)
self.convs = TwoConv(spatial_dims, cat_chns + up_chns, out_chns, act, norm, bias, dropout)
+ self.is_pad = is_pad
def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
"""
@@ -172,13 +157,14 @@ def forward(self, x: torch.Tensor, x_e: Optional[torch.Tensor]):
x_0 = self.upsample(x)
if x_e is not None:
- # handling spatial shapes due to the 2x maxpooling with odd edge lengths.
- dimensions = len(x.shape) - 2
- sp = [0] * (dimensions * 2)
- for i in range(dimensions):
- if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
- sp[i * 2 + 1] = 1
- x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
+ if self.is_pad:
+ # handling spatial shapes due to the 2x maxpooling with odd edge lengths.
+ dimensions = len(x.shape) - 2
+ sp = [0] * (dimensions * 2)
+ for i in range(dimensions):
+ if x_e.shape[-i - 1] != x_0.shape[-i - 1]:
+ sp[i * 2 + 1] = 1
+ x_0 = torch.nn.functional.pad(x_0, sp, "replicate")
x = self.convs(torch.cat([x_e, x_0], dim=1)) # input channels: (cat_chns + up_chns)
else:
x = self.convs(x_0)
@@ -254,7 +240,6 @@ def __init__(
super().__init__()
if dimensions is not None:
spatial_dims = dimensions
-
fea = ensure_tuple_rep(features, 6)
print(f"BasicUNet features: {fea}.")
@@ -275,13 +260,13 @@ def forward(self, x: torch.Tensor):
"""
Args:
x: input should have spatially N dimensions
- ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.
+ ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `spatial_dims`.
It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have
even edge lengths.
Returns:
A torch Tensor of "raw" predictions in shape
- ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
+ ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.
"""
x0 = self.conv_0(x)
diff --git a/monai/networks/nets/basic_unetplusplus.py b/monai/networks/nets/basic_unetplusplus.py
new file mode 100644
index 00000000000..4f7d319aaab
--- /dev/null
+++ b/monai/networks/nets/basic_unetplusplus.py
@@ -0,0 +1,171 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Sequence, Union
+
+import torch
+import torch.nn as nn
+
+from monai.networks.layers.factories import Conv
+from monai.networks.nets.basic_unet import Down, TwoConv, UpCat
+from monai.utils import ensure_tuple_rep
+
+__all__ = ["BasicUnetPlusPlus", "BasicunetPlusPlus", "basicunetplusplus", "BasicUNetPlusPlus"]
+
+
+class BasicUNetPlusPlus(nn.Module):
+ def __init__(
+ self,
+ spatial_dims: int = 3,
+ in_channels: int = 1,
+ out_channels: int = 2,
+ features: Sequence[int] = (32, 32, 64, 128, 256, 32),
+ deep_supervision: bool = False,
+ act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}),
+ norm: Union[str, tuple] = ("instance", {"affine": True}),
+ bias: bool = True,
+ dropout: Union[float, tuple] = 0.0,
+ upsample: str = "deconv",
+ ):
+ """
+ A UNet++ implementation with 1D/2D/3D supports.
+
+ Based on:
+
+ Zhou et al. "UNet++: A Nested U-Net Architecture for Medical Image
+ Segmentation". 4th Deep Learning in Medical Image Analysis (DLMIA)
+ Workshop, DOI: https://doi.org/10.48550/arXiv.1807.10165
+
+
+ Args:
+ spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.
+ in_channels: number of input channels. Defaults to 1.
+ out_channels: number of output channels. Defaults to 2.
+ features: six integers as numbers of features.
+ Defaults to ``(32, 32, 64, 128, 256, 32)``,
+
+ - the first five values correspond to the five-level encoder feature sizes.
+ - the last value corresponds to the feature size after the last upsampling.
+
+ deep_supervision: whether to prune the network at inference time. Defaults to False. If true, returns a list,
+ whose elements correspond to outputs at different nodes.
+ act: activation type and arguments. Defaults to LeakyReLU.
+ norm: feature normalization type and arguments. Defaults to instance norm.
+ bias: whether to have a bias term in convolution blocks. Defaults to True.
+ According to `Performance Tuning Guide `_,
+ if a conv layer is directly followed by a batch norm layer, bias should be False.
+ dropout: dropout ratio. Defaults to no dropout.
+ upsample: upsampling mode, available options are
+ ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
+
+ Examples::
+
+ # for spatial 2D
+ >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))
+
+ # for spatial 2D, with deep supervision enabled
+ >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), deep_supervision=True)
+
+ # for spatial 2D, with group norm
+ >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))
+
+ # for spatial 3D
+ >>> net = BasicUNetPlusPlus(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))
+
+ See Also
+ - :py:class:`monai.networks.nets.BasicUNet`
+ - :py:class:`monai.networks.nets.DynUNet`
+ - :py:class:`monai.networks.nets.UNet`
+
+ """
+ super().__init__()
+
+ self.deep_supervision = deep_supervision
+
+ fea = ensure_tuple_rep(features, 6)
+ print(f"BasicUNetPlusPlus features: {fea}.")
+
+ self.conv_0_0 = TwoConv(spatial_dims, in_channels, fea[0], act, norm, bias, dropout)
+ self.conv_1_0 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout)
+ self.conv_2_0 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout)
+ self.conv_3_0 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout)
+ self.conv_4_0 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout)
+
+ self.upcat_0_1 = UpCat(spatial_dims, fea[1], fea[0], fea[0], act, norm, bias, dropout, upsample, halves=False)
+ self.upcat_1_1 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample)
+ self.upcat_2_1 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample)
+ self.upcat_3_1 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample)
+
+ self.upcat_0_2 = UpCat(
+ spatial_dims, fea[1], fea[0] * 2, fea[0], act, norm, bias, dropout, upsample, halves=False
+ )
+ self.upcat_1_2 = UpCat(spatial_dims, fea[2], fea[1] * 2, fea[1], act, norm, bias, dropout, upsample)
+ self.upcat_2_2 = UpCat(spatial_dims, fea[3], fea[2] * 2, fea[2], act, norm, bias, dropout, upsample)
+
+ self.upcat_0_3 = UpCat(
+ spatial_dims, fea[1], fea[0] * 3, fea[0], act, norm, bias, dropout, upsample, halves=False
+ )
+ self.upcat_1_3 = UpCat(spatial_dims, fea[2], fea[1] * 3, fea[1], act, norm, bias, dropout, upsample)
+
+ self.upcat_0_4 = UpCat(
+ spatial_dims, fea[1], fea[0] * 4, fea[5], act, norm, bias, dropout, upsample, halves=False
+ )
+
+ self.final_conv_0_1 = Conv["conv", spatial_dims](fea[0], out_channels, kernel_size=1)
+ self.final_conv_0_2 = Conv["conv", spatial_dims](fea[0], out_channels, kernel_size=1)
+ self.final_conv_0_3 = Conv["conv", spatial_dims](fea[0], out_channels, kernel_size=1)
+ self.final_conv_0_4 = Conv["conv", spatial_dims](fea[5], out_channels, kernel_size=1)
+
+ def forward(self, x: torch.Tensor):
+ """
+ Args:
+ x: input should have spatially N dimensions
+ ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `dimensions`.
+ It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have
+ even edge lengths.
+
+ Returns:
+ A torch Tensor of "raw" predictions in shape
+ ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.
+ """
+ x_0_0 = self.conv_0_0(x)
+ x_1_0 = self.conv_1_0(x_0_0)
+ x_0_1 = self.upcat_0_1(x_1_0, x_0_0)
+
+ x_2_0 = self.conv_2_0(x_1_0)
+ x_1_1 = self.upcat_1_1(x_2_0, x_1_0)
+ x_0_2 = self.upcat_0_2(x_1_1, torch.cat([x_0_0, x_0_1], dim=1))
+
+ x_3_0 = self.conv_3_0(x_2_0)
+ x_2_1 = self.upcat_2_1(x_3_0, x_2_0)
+ x_1_2 = self.upcat_1_2(x_2_1, torch.cat([x_1_0, x_1_1], dim=1))
+ x_0_3 = self.upcat_0_3(x_1_2, torch.cat([x_0_0, x_0_1, x_0_2], dim=1))
+
+ x_4_0 = self.conv_4_0(x_3_0)
+ x_3_1 = self.upcat_3_1(x_4_0, x_3_0)
+ x_2_2 = self.upcat_2_2(x_3_1, torch.cat([x_2_0, x_2_1], dim=1))
+ x_1_3 = self.upcat_1_3(x_2_2, torch.cat([x_1_0, x_1_1, x_1_2], dim=1))
+ x_0_4 = self.upcat_0_4(x_1_3, torch.cat([x_0_0, x_0_1, x_0_2, x_0_3], dim=1))
+
+ output_0_1 = self.final_conv_0_1(x_0_1)
+ output_0_2 = self.final_conv_0_2(x_0_2)
+ output_0_3 = self.final_conv_0_3(x_0_3)
+ output_0_4 = self.final_conv_0_4(x_0_4)
+
+ if self.deep_supervision:
+ output = [output_0_1, output_0_2, output_0_3, output_0_4]
+ else:
+ output = [output_0_4]
+
+ return output
+
+
+BasicUnetPlusPlus = BasicunetPlusPlus = basicunetplusplus = BasicUNetPlusPlus
diff --git a/monai/networks/nets/densenet.py b/monai/networks/nets/densenet.py
index 52bd2fa9941..2f02ecf395e 100644
--- a/monai/networks/nets/densenet.py
+++ b/monai/networks/nets/densenet.py
@@ -148,7 +148,7 @@ class DenseNet(nn.Module):
"""
Densenet based on: `Densely Connected Convolutional Networks `_.
Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.
- This network is non-determistic When `spatial_dims` is 3 and CUDA is enabled. Please check the link below
+ This network is non-deterministic When `spatial_dims` is 3 and CUDA is enabled. Please check the link below
for more details:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
diff --git a/monai/networks/nets/dints.py b/monai/networks/nets/dints.py
index b7f3921a477..334d0abf0d3 100644
--- a/monai/networks/nets/dints.py
+++ b/monai/networks/nets/dints.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import warnings
from typing import List, Optional, Tuple, Union
diff --git a/monai/networks/nets/dynunet.py b/monai/networks/nets/dynunet.py
index 053ab255b84..ad7251241be 100644
--- a/monai/networks/nets/dynunet.py
+++ b/monai/networks/nets/dynunet.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import List, Optional, Sequence, Tuple, Union
import torch
diff --git a/monai/networks/nets/efficientnet.py b/monai/networks/nets/efficientnet.py
index fa5efbc4ef2..59e66f0713c 100644
--- a/monai/networks/nets/efficientnet.py
+++ b/monai/networks/nets/efficientnet.py
@@ -13,12 +13,13 @@
import operator
import re
from functools import reduce
-from typing import List, NamedTuple, Optional, Tuple, Type, Union
+from typing import Dict, List, NamedTuple, Optional, Tuple, Type, Union
import torch
from torch import nn
from torch.utils import model_zoo
+from monai.networks.blocks import BaseEncoder
from monai.networks.layers.factories import Act, Conv, Pad, Pool
from monai.networks.layers.utils import get_norm_layer
from monai.utils.module import look_up_option
@@ -30,6 +31,7 @@
"drop_connect",
"EfficientNetBNFeatures",
"BlockArgs",
+ "EfficientNetEncoder",
]
efficientnet_params = {
@@ -528,11 +530,8 @@ def __init__(
# check if model_name is valid model
if model_name not in efficientnet_params.keys():
- raise ValueError(
- "invalid model_name {} found, must be one of {} ".format(
- model_name, ", ".join(efficientnet_params.keys())
- )
- )
+ model_name_string = ", ".join(efficientnet_params.keys())
+ raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")
# get network parameters
weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name]
@@ -588,11 +587,8 @@ def __init__(
# check if model_name is valid model
if model_name not in efficientnet_params.keys():
- raise ValueError(
- "invalid model_name {} found, must be one of {} ".format(
- model_name, ", ".join(efficientnet_params.keys())
- )
- )
+ model_name_string = ", ".join(efficientnet_params.keys())
+ raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")
# get network parameters
weight_coeff, depth_coeff, image_size, dropout_rate, dropconnect_rate = efficientnet_params[model_name]
@@ -638,6 +634,80 @@ def forward(self, inputs: torch.Tensor):
return features
+class EfficientNetEncoder(EfficientNetBNFeatures, BaseEncoder):
+ """
+ Wrap the original efficientnet to an encoder for flexible-unet.
+ """
+
+ backbone_names = [
+ "efficientnet-b0",
+ "efficientnet-b1",
+ "efficientnet-b2",
+ "efficientnet-b3",
+ "efficientnet-b4",
+ "efficientnet-b5",
+ "efficientnet-b6",
+ "efficientnet-b7",
+ "efficientnet-b8",
+ "efficientnet-l2",
+ ]
+
+ @classmethod
+ def get_encoder_parameters(cls) -> List[Dict]:
+ """
+ Get the initialization parameter for efficientnet backbones.
+ """
+ parameter_list = []
+ for backbone_name in cls.backbone_names:
+ parameter_list.append(
+ {
+ "model_name": backbone_name,
+ "pretrained": True,
+ "progress": True,
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "num_classes": 1000,
+ "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
+ "adv_prop": "ap" in backbone_name,
+ }
+ )
+ return parameter_list
+
+ @classmethod
+ def num_channels_per_output(cls) -> List[Tuple[int, ...]]:
+ """
+ Get number of efficientnet backbone output feature maps' channel.
+ """
+ return [
+ (16, 24, 40, 112, 320),
+ (16, 24, 40, 112, 320),
+ (16, 24, 48, 120, 352),
+ (24, 32, 48, 136, 384),
+ (24, 32, 56, 160, 448),
+ (24, 40, 64, 176, 512),
+ (32, 40, 72, 200, 576),
+ (32, 48, 80, 224, 640),
+ (32, 56, 88, 248, 704),
+ (72, 104, 176, 480, 1376),
+ ]
+
+ @classmethod
+ def num_outputs(cls) -> List[int]:
+ """
+ Get number of efficientnet backbone output feature maps.
+ Since every backbone contains the same 5 output feature maps,
+ the number list should be `[5] * 10`.
+ """
+ return [5] * 10
+
+ @classmethod
+ def get_encoder_names(cls) -> List[str]:
+ """
+ Get names of efficient backbone.
+ """
+ return cls.backbone_names
+
+
def get_efficientnet_image_size(model_name: str) -> int:
"""
Get the input image size for a given efficientnet model.
@@ -651,9 +721,8 @@ def get_efficientnet_image_size(model_name: str) -> int:
"""
# check if model_name is valid model
if model_name not in efficientnet_params.keys():
- raise ValueError(
- "invalid model_name {} found, must be one of {} ".format(model_name, ", ".join(efficientnet_params.keys()))
- )
+ model_name_string = ", ".join(efficientnet_params.keys())
+ raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")
# return input image size (all dims equal so only need to return for one dim)
_, _, res, _, _ = efficientnet_params[model_name]
@@ -927,15 +996,10 @@ def to_string(self):
A string notation of BlockArgs object arguments.
Example: "r1_k3_s11_e1_i32_o16_se0.25_noskip".
"""
- string = "r{}_k{}_s{}{}_e{}_i{}_o{}_se{}".format(
- self.num_repeat,
- self.kernel_size,
- self.stride,
- self.stride,
- self.expand_ratio,
- self.input_filters,
- self.output_filters,
- self.se_ratio,
+ string = (
+ f"r{self.num_repeat}_k{self.kernel_size}_s{self.stride}{self.stride}"
+ f"_e{self.expand_ratio}_i{self.input_filters}_o{self.output_filters}"
+ f"_se{self.se_ratio}"
)
if not self.id_skip:
diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py
new file mode 100644
index 00000000000..1bd4ac7c9c9
--- /dev/null
+++ b/monai/networks/nets/flexible_unet.py
@@ -0,0 +1,343 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from pydoc import locate
+from typing import List, Optional, Sequence, Tuple, Type, Union
+
+import torch
+from torch import nn
+
+from monai.networks.blocks import BaseEncoder, UpSample
+from monai.networks.layers.factories import Conv
+from monai.networks.layers.utils import get_act_layer
+from monai.networks.nets import EfficientNetEncoder
+from monai.networks.nets.basic_unet import UpCat
+from monai.utils import InterpolateMode, optional_import
+
+__all__ = ["FlexibleUNet", "FlexUNet", "FLEXUNET_BACKBONE", "FlexUNetEncoderRegister"]
+
+
+class FlexUNetEncoderRegister:
+ """
+ A register to regist backbones for the flexible unet. All backbones can be found in
+ register_dict. Please notice each output of backbone must be 2x downsample in spatial
+ dimension of last output. For example, if given a 512x256 2D image and a backbone with
+ 4 outputs. Then spatial size of each encoder output should be 256x128, 128x64, 64x32
+ and 32x16.
+ """
+
+ def __init__(self):
+ self.register_dict = {}
+
+ def regist_class(self, name: Union[Type, str]):
+ """
+ Regist a given class to the encoder dict. Please notice that input class must be a
+ subclass of BaseEncoder.
+ """
+ if isinstance(name, str):
+ tmp_name, has_built_in = optional_import("monai.networks.nets", name=f"{name}") # search built-in
+ if not has_built_in:
+ tmp_name = locate(f"{name}") # search dotted path
+ name = tmp_name
+ if not isinstance(name, type):
+ raise ValueError(f"Cannot find {name} class.")
+
+ if not issubclass(name, BaseEncoder):
+ warnings.warn(
+ f"{name} would better be derived from monai.networks.blocks.BaseEncoder "
+ "or implement all interfaces specified by it."
+ )
+
+ name_string_list = name.get_encoder_names()
+ feature_number_list = name.num_outputs()
+ feature_channel_list = name.num_channels_per_output()
+ parameter_list = name.get_encoder_parameters()
+
+ assert len(name_string_list) == len(feature_number_list) == len(feature_channel_list) == len(parameter_list)
+ for cnt, name_string in enumerate(name_string_list):
+ cur_dict = {
+ "type": name,
+ "feature_number": feature_number_list[cnt],
+ "feature_channel": feature_channel_list[cnt],
+ "parameter": parameter_list[cnt],
+ }
+ self.register_dict[name_string] = cur_dict
+
+
+FLEXUNET_BACKBONE = FlexUNetEncoderRegister()
+FLEXUNET_BACKBONE.regist_class(EfficientNetEncoder)
+
+
+class UNetDecoder(nn.Module):
+ """
+ UNet Decoder.
+ This class refers to `segmentation_models.pytorch
+ `_.
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ encoder_channels: number of output channels for all feature maps in encoder.
+ `len(encoder_channels)` should be no less than 2.
+ decoder_channels: number of output channels for all feature maps in decoder.
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
+ act: activation type and arguments.
+ norm: feature normalization type and arguments.
+ dropout: dropout ratio.
+ bias: whether to have a bias term in convolution blocks in this decoder.
+ upsample: upsampling mode, available options are
+ ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
+ pre_conv: a conv block applied before upsampling.
+ Only used in the "nontrainable" or "pixelshuffle" mode.
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
+ Only used in the "nontrainable" mode.
+ align_corners: set the align_corners parameter for upsample. Defaults to True.
+ Only used in the "nontrainable" mode.
+ is_pad: whether to pad upsampling features to fit the encoder spatial dims.
+
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ encoder_channels: Sequence[int],
+ decoder_channels: Sequence[int],
+ act: Union[str, tuple],
+ norm: Union[str, tuple],
+ dropout: Union[float, tuple],
+ bias: bool,
+ upsample: str,
+ pre_conv: Optional[str],
+ interp_mode: str,
+ align_corners: Optional[bool],
+ is_pad: bool,
+ ):
+
+ super().__init__()
+ if len(encoder_channels) < 2:
+ raise ValueError("the length of `encoder_channels` should be no less than 2.")
+ if len(decoder_channels) != len(encoder_channels) - 1:
+ raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.")
+
+ in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
+ skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
+ halves = [True] * (len(skip_channels) - 1)
+ halves.append(False)
+ blocks = []
+ for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
+ blocks.append(
+ UpCat(
+ spatial_dims=spatial_dims,
+ in_chns=in_chn,
+ cat_chns=skip_chn,
+ out_chns=out_chn,
+ act=act,
+ norm=norm,
+ dropout=dropout,
+ bias=bias,
+ upsample=upsample,
+ pre_conv=pre_conv,
+ interp_mode=interp_mode,
+ align_corners=align_corners,
+ halves=halve,
+ is_pad=is_pad,
+ )
+ )
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, features: List[torch.Tensor], skip_connect: int = 4):
+ skips = features[:-1][::-1]
+ features = features[1:][::-1]
+
+ x = features[0]
+ for i, block in enumerate(self.blocks):
+ if i < skip_connect:
+ skip = skips[i]
+ else:
+ skip = None
+ x = block(x, skip)
+
+ return x
+
+
+class SegmentationHead(nn.Sequential):
+ """
+ Segmentation head.
+ This class refers to `segmentation_models.pytorch
+ `_.
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels for the block.
+ out_channels: number of output channels for the block.
+ kernel_size: kernel size for the conv layer.
+ act: activation type and arguments.
+ scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.
+
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int = 3,
+ act: Optional[Union[Tuple, str]] = None,
+ scale_factor: float = 1.0,
+ ):
+
+ conv_layer = Conv[Conv.CONV, spatial_dims](
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+ )
+ up_layer: nn.Module = nn.Identity()
+ if scale_factor > 1.0:
+ up_layer = UpSample(
+ spatial_dims=spatial_dims,
+ scale_factor=scale_factor,
+ mode="nontrainable",
+ pre_conv=None,
+ interp_mode=InterpolateMode.LINEAR,
+ )
+ if act is not None:
+ act_layer = get_act_layer(act)
+ else:
+ act_layer = nn.Identity()
+ super().__init__(conv_layer, up_layer, act_layer)
+
+
+class FlexibleUNet(nn.Module):
+ """
+ A flexible implementation of UNet-like encoder-decoder architecture.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ backbone: str,
+ pretrained: bool = False,
+ decoder_channels: Tuple = (256, 128, 64, 32, 16),
+ spatial_dims: int = 2,
+ norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
+ act: Union[str, tuple] = ("relu", {"inplace": True}),
+ dropout: Union[float, tuple] = 0.0,
+ decoder_bias: bool = False,
+ upsample: str = "nontrainable",
+ interp_mode: str = "nearest",
+ is_pad: bool = True,
+ ) -> None:
+ """
+ A flexible implement of UNet, in which the backbone/encoder can be replaced with
+ any efficient network. Currently the input must have a 2 or 3 spatial dimension
+ and the spatial size of each dimension must be a multiple of 32 if is_pad parameter
+ is False.
+ Please notice each output of backbone must be 2x downsample in spatial dimension
+ of last output. For example, if given a 512x256 2D image and a backbone with 4 outputs.
+ Spatial size of each encoder output should be 256x128, 128x64, 64x32 and 32x16.
+
+ Args:
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ backbone: name of backbones to initialize, only support efficientnet right now,
+ can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
+ pretrained: whether to initialize pretrained ImageNet weights, only available
+ for spatial_dims=2 and batch norm is used, default to False.
+ decoder_channels: number of output channels for all feature maps in decoder.
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
+ to (256, 128, 64, 32, 16).
+ spatial_dims: number of spatial dimensions, default to 2.
+ norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
+ "momentum": 0.1}).
+ act: activation type and arguments, default to ("relu", {"inplace": True}).
+ dropout: dropout ratio, default to 0.0.
+ decoder_bias: whether to have a bias term in decoder's convolution blocks.
+ upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``,
+ ``"nontrainable"``.
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
+ Only used in the "nontrainable" mode.
+ is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
+ If this parameter is set to "True", the spatial dim of network input can be arbitary
+ size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
+ """
+ super().__init__()
+
+ if backbone not in FLEXUNET_BACKBONE.register_dict:
+ raise ValueError(
+ f"invalid model_name {backbone} found, must be one of {FLEXUNET_BACKBONE.register_dict.keys()}."
+ )
+
+ if spatial_dims not in (2, 3):
+ raise ValueError("spatial_dims can only be 2 or 3.")
+
+ encoder = FLEXUNET_BACKBONE.register_dict[backbone]
+ self.backbone = backbone
+ self.spatial_dims = spatial_dims
+ encoder_parameters = encoder["parameter"]
+ if not (
+ ("spatial_dims" in encoder_parameters)
+ and ("in_channels" in encoder_parameters)
+ and ("pretrained" in encoder_parameters)
+ ):
+ raise ValueError("The backbone init method must have spatial_dims, in_channels and pretrained parameters.")
+ encoder_feature_num = encoder["feature_number"]
+ if encoder_feature_num > 5:
+ raise ValueError("Flexible unet can only accept no more than 5 encoder feature maps.")
+
+ decoder_channels = decoder_channels[:encoder_feature_num]
+ self.skip_connect = encoder_feature_num - 1
+ encoder_parameters.update({"spatial_dims": spatial_dims, "in_channels": in_channels, "pretrained": pretrained})
+ encoder_channels = tuple([in_channels] + list(encoder["feature_channel"]))
+ encoder_type = encoder["type"]
+ self.encoder = encoder_type(**encoder_parameters)
+
+ self.decoder = UNetDecoder(
+ spatial_dims=spatial_dims,
+ encoder_channels=encoder_channels,
+ decoder_channels=decoder_channels,
+ act=act,
+ norm=norm,
+ dropout=dropout,
+ bias=decoder_bias,
+ upsample=upsample,
+ interp_mode=interp_mode,
+ pre_conv=None,
+ align_corners=None,
+ is_pad=is_pad,
+ )
+ self.segmentation_head = SegmentationHead(
+ spatial_dims=spatial_dims,
+ in_channels=decoder_channels[-1],
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None,
+ )
+
+ def forward(self, inputs: torch.Tensor):
+ """
+ Do a typical encoder-decoder-header inference.
+
+ Args:
+ inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
+ N is defined by `dimensions`.
+
+ Returns:
+ A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
+
+ """
+ x = inputs
+ enc_out = self.encoder(x)
+ decoder_out = self.decoder(enc_out, self.skip_connect)
+ x_seg = self.segmentation_head(decoder_out)
+
+ return x_seg
+
+
+FlexUNet = FlexibleUNet
diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py
index 568d6658c5f..0834c5e94f5 100644
--- a/monai/networks/nets/hovernet.py
+++ b/monai/networks/nets/hovernet.py
@@ -15,7 +15,7 @@
# https://github.com/vqdang/hover_net/blob/master/LICENSE
# MIT License
-# Origial publication:
+# Original publication:
# @article{graham2019hover,
# title={Hover-net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images},
# author={Graham, Simon and Vu, Quoc Dang and Raza, Shan E Ahmed and Azam, Ayesha and Tsang, Yee Wah and Kwak,
@@ -25,22 +25,25 @@
# year={2019},
# publisher={Elsevier}
# }
-
# =========================================================================
+import os
+import re
+import warnings
from collections import OrderedDict
-from enum import Enum
-from typing import Callable, Dict, List, Sequence, Type, Union
+from typing import Callable, Dict, List, Optional, Sequence, Type, Union
import torch
import torch.nn as nn
+from monai.apps.utils import download_url
from monai.networks.blocks import UpSample
from monai.networks.layers.factories import Conv, Dropout
from monai.networks.layers.utils import get_act_layer, get_norm_layer
-from monai.utils import InterpolateMode, UpsampleMode, export
+from monai.utils.enums import HoVerNetBranch, HoVerNetMode, InterpolateMode, UpsampleMode
+from monai.utils.module import export, look_up_option
-__all__ = ["HoverNet", "Hovernet", "HoVernet", "HoVerNet"]
+__all__ = ["HoVerNet", "Hovernet", "HoVernet", "HoVerNet"]
class _DenseLayerDecoder(nn.Module):
@@ -53,10 +56,10 @@ def __init__(
act: Union[str, tuple] = ("relu", {"inplace": True}),
norm: Union[str, tuple] = "batch",
kernel_size: int = 3,
+ padding: int = 0,
) -> None:
"""
Args:
- spatial_dims: number of spatial dimensions of the input image.
num_features: number of internal channels used for the layer
in_channels: number of the input channels.
out_channels: number of the output channels.
@@ -64,6 +67,7 @@ def __init__(
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
kernel_size: size of the kernel for >1 convolutions (dependent on mode)
+ padding: padding value for >1 convolutions.
"""
super().__init__()
@@ -78,7 +82,8 @@ def __init__(
self.layers.add_module("conv1/norm", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
self.layers.add_module("conv1/relu2", get_act_layer(name=act))
self.layers.add_module(
- "conv2", conv_type(num_features, out_channels, kernel_size=kernel_size, padding=0, groups=4, bias=False)
+ "conv2",
+ conv_type(num_features, out_channels, kernel_size=kernel_size, padding=padding, groups=4, bias=False),
)
if dropout_prob > 0:
@@ -87,7 +92,7 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
x1 = self.layers(x)
- if x1.shape != x.shape:
+ if x1.shape[-1] != x.shape[-1]:
trim = (x.shape[-1] - x1.shape[-1]) // 2
x = x[:, :, trim:-trim, trim:-trim]
@@ -107,10 +112,10 @@ def __init__(
act: Union[str, tuple] = ("relu", {"inplace": True}),
norm: Union[str, tuple] = "batch",
kernel_size: int = 3,
+ same_padding: bool = False,
) -> None:
"""
Args:
- spatial_dims: number of spatial dimensions of the input image.
layers: number of layers in the block.
num_features: number of internal features used.
in_channels: number of the input channel.
@@ -119,17 +124,30 @@ def __init__(
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
kernel_size: size of the kernel for >1 convolutions (dependent on mode)
+ same_padding: whether to do padding for >1 convolutions to ensure
+ the output size is the same as the input size.
"""
super().__init__()
conv_type: Callable = Conv[Conv.CONV, 2]
- self.add_module("conva", conv_type(in_channels, in_channels // 4, kernel_size=kernel_size, bias=False))
+ padding: int = kernel_size // 2 if same_padding else 0
+
+ self.add_module(
+ "conva", conv_type(in_channels, in_channels // 4, kernel_size=kernel_size, padding=padding, bias=False)
+ )
_in_channels = in_channels // 4
for i in range(layers):
layer = _DenseLayerDecoder(
- num_features, _in_channels, out_channels, dropout_prob, act=act, norm=norm, kernel_size=kernel_size
+ num_features,
+ _in_channels,
+ out_channels,
+ dropout_prob,
+ act=act,
+ norm=norm,
+ kernel_size=kernel_size,
+ padding=padding,
)
_in_channels += out_channels
self.add_module("denselayerdecoder%d" % (i + 1), layer)
@@ -175,22 +193,24 @@ def __init__(
dropout_type: Callable = Dropout[Dropout.DROPOUT, 2]
if not drop_first_norm_relu:
- self.layers.add_module("preact_norm", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))
- self.layers.add_module("preact_relu", get_act_layer(name=act))
+ self.layers.add_module("preact/bn", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))
+ self.layers.add_module("preact/relu", get_act_layer(name=act))
self.layers.add_module("conv1", conv_type(in_channels, num_features, kernel_size=1, padding=0, bias=False))
- self.layers.add_module("norm2", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
- self.layers.add_module("relu2", get_act_layer(name=act))
+ self.layers.add_module("conv1/bn", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
+ self.layers.add_module("conv1/relu", get_act_layer(name=act))
if in_channels != 64 and drop_first_norm_relu:
self.layers.add_module(
"conv2", conv_type(num_features, num_features, kernel_size=kernel_size, stride=2, padding=2, bias=False)
)
else:
- self.layers.add_module("conv2", conv_type(num_features, num_features, kernel_size=1, padding=0, bias=False))
+ self.layers.add_module(
+ "conv2", conv_type(num_features, num_features, kernel_size=kernel_size, padding=1, bias=False)
+ )
- self.layers.add_module("norm3", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
- self.layers.add_module("relu3", get_act_layer(name=act))
+ self.layers.add_module("conv2/bn", get_norm_layer(name=norm, spatial_dims=2, channels=num_features))
+ self.layers.add_module("conv2/relu", get_act_layer(name=act))
self.layers.add_module("conv3", conv_type(num_features, out_channels, kernel_size=1, padding=0, bias=False))
if dropout_prob > 0:
@@ -209,7 +229,7 @@ def __init__(
"""
super().__init__()
- self.add_module("norm", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))
+ self.add_module("bn", get_norm_layer(name=norm, spatial_dims=2, channels=in_channels))
self.add_module("relu", get_act_layer(name=act))
@@ -253,11 +273,11 @@ def __init__(
layer = _DenseLayer(
num_features, in_channels, out_channels, dropout_prob, act=act, norm=norm, drop_first_norm_relu=True
)
- self.layers.add_module("prim_denselayer%d" % (1), layer)
+ self.layers.add_module("denselayer_0", layer)
for i in range(1, layers):
layer = _DenseLayer(num_features, out_channels, out_channels, dropout_prob, act=act, norm=norm)
- self.layers.add_module("main_denselayer%d" % (i + 1), layer)
+ self.layers.add_module(f"denselayer_{i}", layer)
self.bna_block = _Transition(out_channels, act=act, norm=norm)
@@ -290,6 +310,7 @@ def __init__(
dropout_prob: float = 0.0,
out_channels: int = 2,
kernel_size: int = 3,
+ same_padding: bool = False,
) -> None:
"""
Args:
@@ -297,9 +318,10 @@ def __init__(
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
dropout_prob: dropout rate after each dense layer.
- num_features: number of internal features used.
out_channels: number of the output channel.
kernel_size: size of the kernel for >1 convolutions (dependent on mode)
+ same_padding: whether to do padding for >1 convolutions to ensure
+ the output size is the same as the input size.
"""
super().__init__()
conv_type: Callable = Conv[Conv.CONV, 2]
@@ -320,6 +342,7 @@ def __init__(
act=act,
norm=norm,
kernel_size=kernel_size,
+ same_padding=same_padding,
)
self.decoder_blocks.add_module(f"decoderblock{i + 1}", block)
_in_channels = 512
@@ -339,7 +362,7 @@ def __init__(
_seq_block = nn.Sequential(
OrderedDict(
[
- ("norm", get_norm_layer(name=norm, spatial_dims=2, channels=64)),
+ ("bn", get_norm_layer(name=norm, spatial_dims=2, channels=64)),
("relu", get_act_layer(name=act)),
("conv", conv_type(64, out_channels, kernel_size=1, stride=1)),
]
@@ -362,7 +385,8 @@ def forward(self, xin: torch.Tensor, short_cuts: List[torch.Tensor]) -> torch.Te
x = self.upsample(x)
block_number -= 1
trim = (short_cuts[block_number].shape[-1] - x.shape[-1]) // 2
- x += short_cuts[block_number][:, :, trim:-trim, trim:-trim]
+ if trim > 0:
+ x += short_cuts[block_number][:, :, trim:-trim, trim:-trim]
for block in self.output_features:
x = block(x)
@@ -371,49 +395,62 @@ def forward(self, xin: torch.Tensor, short_cuts: List[torch.Tensor]) -> torch.Te
@export("monai.networks.nets")
-class HoverNet(nn.Module):
- """HoVerNet
+class HoVerNet(nn.Module):
+ """HoVerNet model
References:
Graham, Simon et al. Hover-net: Simultaneous segmentation
and classification of nuclei in multi-tissue histology images,
Medical Image Analysis 2019
+ https://github.com/vqdang/hover_net
+
Args:
+ mode: use original implementation (`HoVerNetMODE.ORIGINAL` or "original") or
+ a faster implementation (`HoVerNetMODE.FAST` or "fast"). Defaults to `HoVerNetMODE.FAST`.
in_channels: number of the input channel.
+ np_out_channels: number of the output channel of the nucleus prediction branch.
out_classes: number of the nuclear type classes.
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
+ decoder_padding: whether to do padding on convolution layers in the decoders. In the conic branch
+ of the referred repository, the architecture is changed to do padding on convolution layers in order to
+ get the same output size as the input, and this changed version is used on CoNIC challenge.
+ Please note that to get consistent output size, `HoVerNetMode.FAST` mode should be employed.
dropout_prob: dropout rate after each dense layer.
+ pretrained_url: if specifying, will loaded the pretrained weights downloaded from the url.
+ The weights should be ImageNet pretrained preact-resnet50 weights coming from the referred hover_net
+ repository, each user is responsible for checking the content of model/datasets and the applicable licenses
+ and determining if suitable for the intended use. please check the following link for more details:
+ https://github.com/vqdang/hover_net#data-format
"""
- class Mode(Enum):
- FAST: int = 0
- ORIGINAL: int = 1
-
- def _mode_to_int(self, mode) -> int:
-
- if mode == self.Mode.FAST:
- return 0
- else:
- return 1
+ Mode = HoVerNetMode
+ Branch = HoVerNetBranch
def __init__(
self,
- mode: Mode = Mode.FAST,
+ mode: Union[HoVerNetMode, str] = HoVerNetMode.FAST,
in_channels: int = 3,
+ np_out_channels: int = 2,
out_classes: int = 0,
act: Union[str, tuple] = ("relu", {"inplace": True}),
norm: Union[str, tuple] = "batch",
+ decoder_padding: bool = False,
dropout_prob: float = 0.0,
+ pretrained_url: Optional[str] = None,
) -> None:
super().__init__()
- self.mode: int = self._mode_to_int(mode)
+ if isinstance(mode, str):
+ mode = mode.upper()
+ self.mode = look_up_option(mode, HoVerNetMode)
- if mode not in [self.Mode.ORIGINAL, self.Mode.FAST]:
- raise ValueError("Input size should be 270 x 270 when using Mode.ORIGINAL")
+ if self.mode == "ORIGINAL" and decoder_padding is True:
+ warnings.warn(
+ "'decoder_padding=True' only works when mode is 'FAST', otherwise the output size may not equal to the input."
+ )
if out_classes > 128:
raise ValueError("Number of nuclear types classes exceeds maximum (128)")
@@ -428,7 +465,7 @@ def __init__(
# number of layers in each pooling block.
_block_config: Sequence[int] = (3, 4, 6, 3)
- if mode == self.Mode.FAST:
+ if self.mode == HoVerNetMode.FAST:
_ksize = 3
_pad = 3
else:
@@ -437,15 +474,12 @@ def __init__(
conv_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2]
- self.input_features = nn.Sequential(
+ self.conv0 = nn.Sequential(
OrderedDict(
[
- (
- "conv0",
- conv_type(in_channels, _init_features, kernel_size=7, stride=1, padding=_pad, bias=False),
- ),
- ("norm0", get_norm_layer(name=norm, spatial_dims=2, channels=_init_features)),
- ("relu0", get_act_layer(name=act)),
+ ("conv", conv_type(in_channels, _init_features, kernel_size=7, stride=1, padding=_pad, bias=False)),
+ ("bn", get_norm_layer(name=norm, spatial_dims=2, channels=_init_features)),
+ ("relu", get_act_layer(name=act)),
]
)
)
@@ -466,7 +500,7 @@ def __init__(
act=act,
norm=norm,
)
- self.res_blocks.add_module(f"residualblock{i + 1}", block)
+ self.res_blocks.add_module(f"d{i}", block)
_in_channels = _out_channels
_out_channels *= 2
@@ -482,12 +516,15 @@ def __init__(
)
# decode branches
- self.nucleus_prediction = _DecoderBranch(kernel_size=_ksize)
- self.horizontal_vertical = _DecoderBranch(kernel_size=_ksize)
- self.type_prediction: _DecoderBranch = None # type: ignore
-
- if out_classes > 0:
- self.type_prediction = _DecoderBranch(out_channels=out_classes, kernel_size=_ksize)
+ self.nucleus_prediction = _DecoderBranch(
+ kernel_size=_ksize, same_padding=decoder_padding, out_channels=np_out_channels
+ )
+ self.horizontal_vertical = _DecoderBranch(kernel_size=_ksize, same_padding=decoder_padding)
+ self.type_prediction: Optional[_DecoderBranch] = (
+ _DecoderBranch(out_channels=out_classes, kernel_size=_ksize, same_padding=decoder_padding)
+ if out_classes > 0
+ else None
+ )
for m in self.modules():
if isinstance(m, conv_type):
@@ -496,17 +533,22 @@ def __init__(
nn.init.constant_(torch.as_tensor(m.weight), 1)
nn.init.constant_(torch.as_tensor(m.bias), 0)
+ if pretrained_url is not None:
+ _load_pretrained_encoder(self, pretrained_url)
+
+ def freeze_encoder(self):
+ self.res_blocks.requires_grad_(False)
+
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
- if self.mode == 1:
+ if self.mode == HoVerNetMode.ORIGINAL.value:
if x.shape[-1] != 270 or x.shape[-2] != 270:
- raise ValueError("Input size should be 270 x 270 when using Mode.ORIGINAL")
+ raise ValueError("Input size should be 270 x 270 when using HoVerNetMode.ORIGINAL")
else:
if x.shape[-1] != 256 or x.shape[-2] != 256:
- raise ValueError("Input size should be 256 x 256 when using Mode.FAST")
+ raise ValueError("Input size should be 256 x 256 when using HoVerNetMode.FAST")
- x = x / 255.0 # to 0-1 range to match XY
- x = self.input_features(x)
+ x = self.conv0(x)
short_cuts = []
for i, block in enumerate(self.res_blocks):
@@ -518,15 +560,48 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
x = self.bottleneck(x)
x = self.upsample(x)
- x_np = self.nucleus_prediction(x, short_cuts)
- x_hv = self.horizontal_vertical(x, short_cuts)
- tp = self.type_prediction
-
- if tp is not None:
- x_tp = self.type_prediction(x, short_cuts)
- return {"nucleus_prediction": x_np, "horizonal_vertical": x_hv, "type_prediction": x_tp}
-
- return {"nucleus_prediction": x_np, "horizonal_vertical": x_hv}
-
-
-Hovernet = HoVernet = HoVerNet = HoverNet
+ output = {
+ HoVerNetBranch.NP.value: self.nucleus_prediction(x, short_cuts),
+ HoVerNetBranch.HV.value: self.horizontal_vertical(x, short_cuts),
+ }
+ if self.type_prediction is not None:
+ output[HoVerNetBranch.NC.value] = self.type_prediction(x, short_cuts)
+
+ return output
+
+
+def _load_pretrained_encoder(model: nn.Module, model_url: str):
+
+ pattern_conv0 = re.compile(r"^(conv0\.\/)(.+)$")
+ pattern_block = re.compile(r"^(d\d+)\.(.+)$")
+ pattern_layer = re.compile(r"^(.+\.d\d+)\.units\.(\d+)(.+)$")
+ pattern_bna = re.compile(r"^(.+\.d\d+)\.blk_bna\.(.+)")
+ # download the pretrained weights into torch hub's default dir
+ weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
+ download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
+ state_dict = torch.load(weights_dir, map_location=None)["desc"]
+ for key in list(state_dict.keys()):
+ new_key = None
+ if pattern_conv0.match(key):
+ new_key = re.sub(pattern_conv0, r"conv0.conv\2", key)
+ elif pattern_block.match(key):
+ new_key = re.sub(pattern_block, r"res_blocks.\1.\2", key)
+ if pattern_layer.match(new_key):
+ new_key = re.sub(pattern_layer, r"\1.layers.denselayer_\2.layers\3", new_key)
+ elif pattern_bna.match(new_key):
+ new_key = re.sub(pattern_bna, r"\1.bna_block.\2", new_key)
+ if new_key:
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+ if "upsample2x" in key:
+ del state_dict[key]
+
+ model_dict = model.state_dict()
+ state_dict = {
+ k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
+ }
+ model_dict.update(state_dict)
+ model.load_state_dict(model_dict)
+
+
+Hovernet = HoVernet = HoverNet = HoVerNet
diff --git a/monai/networks/nets/netadapter.py b/monai/networks/nets/netadapter.py
index 425c1d58208..39112e4d549 100644
--- a/monai/networks/nets/netadapter.py
+++ b/monai/networks/nets/netadapter.py
@@ -14,14 +14,18 @@
import torch
from monai.networks.layers import Conv, get_pool_layer
-from monai.utils import deprecated_arg
+from monai.networks.utils import look_up_named_module, set_named_module
+from monai.utils import look_up_option, optional_import
+
+get_graph_node_names, _has_utils = optional_import("torchvision.models.feature_extraction", name="get_graph_node_names")
+create_feature_extractor, _ = optional_import("torchvision.models.feature_extraction", name="create_feature_extractor")
class NetAdapter(torch.nn.Module):
"""
Wrapper to replace the last layer of model by convolutional layer or FC layer.
- This module expects the output of `model layers[0: -2]` is a feature map with shape [B, C, spatial dims],
- then replace the model's last two layers with an optional `pooling` and a `conv` or `linear` layer.
+
+ See also: :py:class:`monai.networks.nets.TorchVisionFCModel`
Args:
model: a PyTorch model, which can be both 2D and 3D models. typically, it can be a pretrained model
@@ -31,19 +35,18 @@ class NetAdapter(torch.nn.Module):
dim: number of supported spatial dimensions in the specified model, depends on the model implementation.
default to 2 as most Torchvision models are for 2D image processing.
in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
- use_conv: whether use convolutional layer to replace the last layer, default to False.
+ use_conv: whether to use convolutional layer to replace the last layer, default to False.
pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer,
the second item is dictionary of the initialization args. if None, will not replace the `layers[-2]`.
default to `("avg", {"kernel_size": 7, "stride": 1})`.
bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,
default to True.
-
- .. deprecated:: 0.6.0
- ``n_classes`` is deprecated, use ``num_classes`` instead.
+ fc_name: the corresponding layer attribute of the last fully connected layer. Defaults to ``"fc"``.
+ node_name: the corresponding feature extractor node name of `model`.
+ Defaults to "", the extractor is not in use.
"""
- @deprecated_arg("n_classes", since="0.6")
def __init__(
self,
model: torch.nn.Module,
@@ -53,51 +56,66 @@ def __init__(
use_conv: bool = False,
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
bias: bool = True,
- n_classes: Optional[int] = None,
+ fc_name: str = "fc",
+ node_name: str = "",
):
super().__init__()
- # in case the new num_classes is default but you still call deprecated n_classes
- if n_classes is not None and num_classes == 1:
- num_classes = n_classes
layers = list(model.children())
- orig_fc = layers[-1]
+ orig_fc = look_up_named_module(fc_name, model)
+ if orig_fc is None:
+ orig_fc = layers[-1]
+ # guess the number of input channels of the last fully connected layer
in_channels_: int
-
if in_channels is None:
if not hasattr(orig_fc, "in_features"):
- raise ValueError("please specify the input channels of last layer with arg `in_channels`.")
- in_channels_ = orig_fc.in_features # type: ignore
+ raise ValueError("please specify input channels of the last fully connected layer with `in_channels`.")
+ in_channels_ = orig_fc.in_features
+
else:
in_channels_ = in_channels
- if pool is None:
- # remove the last layer
- self.features = torch.nn.Sequential(*layers[:-1])
+ # modify the input model, depending on whether to replace the last pooling layer ``pool``
+ if pool is None: # no modification of pooling
+ if node_name != "":
+ raise ValueError("`node_name` is not compatible with `pool=None`, please set `pool=''`.")
+ # we just drop the model's fully connected layer or set it to identity
+ if look_up_named_module(fc_name, model):
+ self.features = set_named_module(model, fc_name, torch.nn.Identity())
+ else:
+ self.features = torch.nn.Sequential(*layers[:-1]) # assuming FC is the last and model is sequential
self.pool = None
else:
- # remove the last 2 layers
- self.features = torch.nn.Sequential(*layers[:-2])
+ # user-specified new pooling layer, we drop both the pooling and FC layers from the model
+ if node_name and _has_utils:
+ node_name = look_up_option(node_name, get_graph_node_names(model)[0 if model.training else 1])
+ self.features = create_feature_extractor(model, [node_name])
+ else:
+ self.features = torch.nn.Sequential(*layers[:-2]) # assuming the last 2 layers are pooling&FC
self.pool = get_pool_layer(name=pool, spatial_dims=dim)
+ # create new fully connected layer or kernel size 1 convolutional layer
self.fc: Union[torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv3d]
if use_conv:
- # add 1x1 conv (it behaves like a FC layer)
self.fc = Conv[Conv.CONV, dim](in_channels=in_channels_, out_channels=num_classes, kernel_size=1, bias=bias)
else:
- # remove the last Linear layer (fully connected)
- self.features = torch.nn.Sequential(*layers[:-1])
- # replace the out_features of FC layer
self.fc = torch.nn.Linear(in_features=in_channels_, out_features=num_classes, bias=bias)
self.use_conv = use_conv
+ self.dim = dim
+ self.node_name = node_name
def forward(self, x):
x = self.features(x)
+ if isinstance(x, tuple):
+ x = x[0] # it might be a namedtuple such as torchvision.model.InceptionOutputs
+ elif torch.jit.isinstance(x, Dict[str, torch.Tensor]):
+ x = x[self.node_name] # torchvision create_feature_extractor
if self.pool is not None:
x = self.pool(x)
-
if not self.use_conv:
x = torch.flatten(x, 1)
-
+ else: # user specified `use_conv` but the pooling layer removed the spatial dims
+ while len(x.shape) < self.dim + 2:
+ x = x[..., None]
x = self.fc(x)
return x
diff --git a/monai/networks/nets/regunet.py b/monai/networks/nets/regunet.py
index 6776c7ce9e6..db56b68b66d 100644
--- a/monai/networks/nets/regunet.py
+++ b/monai/networks/nets/regunet.py
@@ -167,8 +167,6 @@ def build_bottom_block(self, in_channels: int, out_channels: int):
)
def build_decode_layers(self):
- # decoding / up-sampling
- # [depth - 1, depth - 2, ..., min_extract_level]
self.decode_deconvs = nn.ModuleList(
[
self.build_up_sampling_block(in_channels=self.num_channels[d + 1], out_channels=self.num_channels[d])
@@ -221,9 +219,7 @@ def forward(self, x):
outs = [decoded]
- # [depth - 1, ..., min_extract_level]
for i, (decode_deconv, decode_conv) in enumerate(zip(self.decode_deconvs, self.decode_convs)):
- # [depth - 1, depth - 2, ..., min_extract_level]
decoded = decode_deconv(decoded)
if self.concat_skip:
decoded = torch.cat([decoded, skips[-i - 1]], dim=1)
diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py
index bf5486f06e4..e923c1bb7dc 100644
--- a/monai/networks/nets/resnet.py
+++ b/monai/networks/nets/resnet.py
@@ -10,7 +10,7 @@
# limitations under the License.
from functools import partial
-from typing import Any, Callable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, List, Tuple, Type, Union
import torch
import torch.nn as nn
@@ -20,9 +20,18 @@
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option
-__all__ = ["ResNet", "resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]
-
-from monai.utils import deprecated_arg
+__all__ = [
+ "ResNet",
+ "ResNetBlock",
+ "ResNetBottleneck",
+ "resnet10",
+ "resnet18",
+ "resnet34",
+ "resnet50",
+ "resnet101",
+ "resnet152",
+ "resnet200",
+]
def get_inplanes():
@@ -167,12 +176,8 @@ class ResNet(nn.Module):
num_classes: number of output (classifications).
feed_forward: whether to add the FC layer for the output, default to `True`.
- .. deprecated:: 0.6.0
- ``n_classes`` is deprecated, use ``num_classes`` instead.
-
"""
- @deprecated_arg("n_classes", since="0.6")
def __init__(
self,
block: Union[Type[Union[ResNetBlock, ResNetBottleneck]], str],
@@ -187,13 +192,9 @@ def __init__(
widen_factor: float = 1.0,
num_classes: int = 400,
feed_forward: bool = True,
- n_classes: Optional[int] = None,
) -> None:
super().__init__()
- # in case the new num_classes is default but you still call deprecated n_classes
- if n_classes is not None and num_classes == 400:
- num_classes = n_classes
if isinstance(block, str):
if block == "basic":
diff --git a/monai/networks/nets/segresnet.py b/monai/networks/nets/segresnet.py
index 299f1ca811a..cc908f16409 100644
--- a/monai/networks/nets/segresnet.py
+++ b/monai/networks/nets/segresnet.py
@@ -101,7 +101,7 @@ def __init__(
def _make_down_layers(self):
down_layers = nn.ModuleList()
blocks_down, spatial_dims, filters, norm = (self.blocks_down, self.spatial_dims, self.init_filters, self.norm)
- for i in range(len(blocks_down)):
+ for i, item in enumerate(blocks_down):
layer_in_channels = filters * 2**i
pre_conv = (
get_conv_layer(spatial_dims, layer_in_channels // 2, layer_in_channels, stride=2)
@@ -109,8 +109,7 @@ def _make_down_layers(self):
else nn.Identity()
)
down_layer = nn.Sequential(
- pre_conv,
- *[ResBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act) for _ in range(blocks_down[i])],
+ pre_conv, *[ResBlock(spatial_dims, layer_in_channels, norm=norm, act=self.act) for _ in range(item)]
)
down_layers.append(down_layer)
return down_layers
diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py
new file mode 100644
index 00000000000..a440e28ba70
--- /dev/null
+++ b/monai/networks/nets/segresnet_ds.py
@@ -0,0 +1,430 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks.upsample import UpSample
+from monai.networks.layers.factories import Act, Conv, Norm, split_args
+from monai.networks.layers.utils import get_act_layer, get_norm_layer
+from monai.utils import UpsampleMode, has_option
+
+__all__ = ["SegResNetDS"]
+
+
+def scales_for_resolution(resolution: Union[Tuple, List], n_stages: Optional[int] = None):
+ """
+ A helper function to compute a schedule of scale at different downsampling levels,
+ given the input resolution.
+
+ .. code-block:: python
+
+ scales_for_resolution(resolution=[1,1,5], n_stages=5)
+
+ Args:
+ resolution: input image resolution (in mm)
+ n_stages: optionally the number of stages of the network
+ """
+
+ ndim = len(resolution)
+ res = np.array(resolution)
+ if not all(res > 0):
+ raise ValueError("Resolution must be positive")
+
+ nl = np.floor(np.log2(np.max(res) / res)).astype(np.int32)
+ scales = [tuple(np.where(2**i >= 2**nl, 1, 2)) for i in range(max(nl))]
+ if n_stages and n_stages > max(nl):
+ scales = scales + [(2,) * ndim] * (n_stages - max(nl))
+ else:
+ scales = scales[:n_stages]
+ return scales
+
+
+def aniso_kernel(scale: Union[Tuple, List]):
+ """
+ A helper function to compute kernel_size, padding and stride for the given scale
+
+ Args:
+ scale: scale from a current scale level
+ """
+ kernel_size = [3 if scale[k] > 1 else 1 for k in range(len(scale))]
+ padding = [k // 2 for k in kernel_size]
+ return kernel_size, padding, scale
+
+
+class SegResBlock(nn.Module):
+ """
+ Residual network block used SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization
+ `_.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ norm: Union[Tuple, str],
+ kernel_size: Union[Tuple, int] = 3,
+ act: Union[Tuple, str] = "relu",
+ ) -> None:
+ """
+ Args:
+ spatial_dims: number of spatial dimensions, could be 1, 2 or 3.
+ in_channels: number of input channels.
+ norm: feature normalization type and arguments.
+ kernel_size: convolution kernel size. Defaults to 3.
+ act: activation type and arguments. Defaults to ``RELU``.
+ """
+ super().__init__()
+
+ if isinstance(kernel_size, (tuple, list)):
+ padding = tuple(k // 2 for k in kernel_size)
+ else:
+ padding = kernel_size // 2 # type: ignore
+
+ self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
+ self.act1 = get_act_layer(act)
+ self.conv1 = Conv[Conv.CONV, spatial_dims](
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding,
+ bias=False,
+ )
+
+ self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
+ self.act2 = get_act_layer(act)
+ self.conv2 = Conv[Conv.CONV, spatial_dims](
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding,
+ bias=False,
+ )
+
+ def forward(self, x):
+ identity = x
+ x = self.conv1(self.act1(self.norm1(x)))
+ x = self.conv2(self.act2(self.norm2(x)))
+ x += identity
+ return x
+
+
+class SegResEncoder(nn.Module):
+ """
+ SegResEncoder based on the econder structure in `3D MRI brain tumor segmentation using autoencoder regularization
+ `_.
+
+ Args:
+ spatial_dims: spatial dimension of the input data. Defaults to 3.
+ init_filters: number of output channels for initial convolution layer. Defaults to 32.
+ in_channels: number of input channels for the network. Defaults to 1.
+ out_channels: number of output channels for the network. Defaults to 2.
+ act: activation type and arguments. Defaults to ``RELU``.
+ norm: feature normalization type and arguments. Defaults to ``BATCH``.
+ blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
+ head_module: optional callable module to apply to the final features.
+ anisotropic_scales: optional list of scale for each scale level.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int = 3,
+ init_filters: int = 32,
+ in_channels: int = 1,
+ act: Union[Tuple, str] = "relu",
+ norm: Union[Tuple, str] = "batch",
+ blocks_down: tuple = (1, 2, 2, 4),
+ head_module: Optional[nn.Module] = None,
+ anisotropic_scales: Optional[Tuple] = None,
+ ):
+
+ super().__init__()
+
+ if spatial_dims not in (1, 2, 3):
+ raise ValueError("`spatial_dims` can only be 1, 2 or 3.")
+
+ # ensure normalization has affine trainable parameters (if not specified)
+ norm = split_args(norm)
+ if has_option(Norm[norm[0], spatial_dims], "affine"):
+ norm[1].setdefault("affine", True) # type: ignore
+
+ # ensure activation is inplace (if not specified)
+ act = split_args(act)
+ if has_option(Act[act[0]], "inplace"):
+ act[1].setdefault("inplace", True) # type: ignore
+
+ filters = init_filters # base number of features
+
+ kernel_size, padding, _ = aniso_kernel(anisotropic_scales[0]) if anisotropic_scales else (3, 1, 1)
+ self.conv_init = Conv[Conv.CONV, spatial_dims](
+ in_channels=in_channels,
+ out_channels=filters,
+ kernel_size=kernel_size,
+ padding=padding,
+ stride=1,
+ bias=False,
+ )
+ self.layers = nn.ModuleList()
+
+ for i in range(len(blocks_down)):
+ level = nn.ModuleDict()
+
+ kernel_size, padding, stride = aniso_kernel(anisotropic_scales[i]) if anisotropic_scales else (3, 1, 2)
+ blocks = [
+ SegResBlock(spatial_dims=spatial_dims, in_channels=filters, kernel_size=kernel_size, norm=norm, act=act)
+ for _ in range(blocks_down[i])
+ ]
+ level["blocks"] = nn.Sequential(*blocks)
+
+ if i < len(blocks_down) - 1:
+ level["downsample"] = Conv[Conv.CONV, spatial_dims](
+ in_channels=filters,
+ out_channels=2 * filters,
+ bias=False,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ level["downsample"] = nn.Identity()
+
+ self.layers.append(level)
+ filters *= 2
+
+ self.head_module = head_module
+ self.in_channels = in_channels
+ self.blocks_down = blocks_down
+ self.init_filters = init_filters
+ self.norm = norm
+ self.act = act
+ self.spatial_dims = spatial_dims
+
+ def _forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+
+ outputs = []
+ x = self.conv_init(x)
+
+ for level in self.layers:
+ x = level["blocks"](x)
+ outputs.append(x)
+ x = level["downsample"](x)
+
+ if self.head_module is not None:
+ outputs = self.head_module(outputs)
+
+ return outputs
+
+ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+ return self._forward(x)
+
+
+class SegResNetDS(nn.Module):
+ """
+ SegResNetDS based on `3D MRI brain tumor segmentation using autoencoder regularization
+ `_.
+ It is similar to https://docs.monai.io/en/stable/networks.html#segresnet, with several
+ improvements including deep supervision and non-isotropic kernel support.
+
+ Args:
+ spatial_dims: spatial dimension of the input data. Defaults to 3.
+ init_filters: number of output channels for initial convolution layer. Defaults to 32.
+ in_channels: number of input channels for the network. Defaults to 1.
+ out_channels: number of output channels for the network. Defaults to 2.
+ act: activation type and arguments. Defaults to ``RELU``.
+ norm: feature normalization type and arguments. Defaults to ``BATCH``.
+ blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
+ blocks_up: number of upsample blocks (optional).
+ dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
+ At dsdepth==1,only a single output is returned.
+ preprocess: optional callable function to apply before the model's forward pass
+ resolution: optional input image resolution. When provided, the nework will first use non-isotropic kernels to bring
+ image spacing into an approximately isotropic space.
+ Otherwise, by default, the kernel size and downsampling is always isotropic.
+
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int = 3,
+ init_filters: int = 32,
+ in_channels: int = 1,
+ out_channels: int = 2,
+ act: Union[Tuple, str] = "relu",
+ norm: Union[Tuple, str] = "batch",
+ blocks_down: tuple = (1, 2, 2, 4),
+ blocks_up: Optional[Tuple] = None,
+ dsdepth: int = 1,
+ preprocess: Optional[Union[nn.Module, Callable]] = None,
+ upsample_mode: Union[UpsampleMode, str] = "deconv",
+ resolution: Optional[Tuple] = None,
+ ):
+
+ super().__init__()
+
+ if spatial_dims not in (1, 2, 3):
+ raise ValueError("`spatial_dims` can only be 1, 2 or 3.")
+
+ self.spatial_dims = spatial_dims
+ self.init_filters = init_filters
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.act = act
+ self.norm = norm
+ self.blocks_down = blocks_down
+ self.dsdepth = max(dsdepth, 1)
+ self.resolution = resolution
+ self.preprocess = preprocess
+
+ if resolution is not None:
+ if not isinstance(resolution, (list, tuple)):
+ raise TypeError("resolution must be a tuple")
+ elif not all(r > 0 for r in resolution):
+ raise ValueError("resolution must be positive")
+
+ # ensure normalization had affine trainable parameters (if not specified)
+ norm = split_args(norm)
+ if has_option(Norm[norm[0], spatial_dims], "affine"):
+ norm[1].setdefault("affine", True) # type: ignore
+
+ # ensure activation is inplace (if not specified)
+ act = split_args(act)
+ if has_option(Act[act[0]], "inplace"):
+ act[1].setdefault("inplace", True) # type: ignore
+
+ anisotropic_scales = None
+ if resolution:
+ anisotropic_scales = scales_for_resolution(resolution, n_stages=len(blocks_down))
+ self.anisotropic_scales = anisotropic_scales
+
+ self.encoder = SegResEncoder(
+ spatial_dims=spatial_dims,
+ init_filters=init_filters,
+ in_channels=in_channels,
+ act=act,
+ norm=norm,
+ blocks_down=blocks_down,
+ anisotropic_scales=anisotropic_scales,
+ )
+
+ n_up = len(blocks_down) - 1
+ if blocks_up is None:
+ blocks_up = (1,) * n_up # assume 1 upsample block per level
+ self.blocks_up = blocks_up
+
+ filters = init_filters * 2**n_up
+ self.up_layers = nn.ModuleList()
+
+ for i in range(n_up):
+
+ filters = filters // 2
+ kernel_size, _, stride = (
+ aniso_kernel(anisotropic_scales[len(blocks_up) - i - 1]) if anisotropic_scales else (3, 1, 2)
+ )
+
+ level = nn.ModuleDict()
+ level["upsample"] = UpSample(
+ mode=upsample_mode,
+ spatial_dims=spatial_dims,
+ in_channels=2 * filters,
+ out_channels=filters,
+ kernel_size=kernel_size,
+ scale_factor=stride,
+ bias=False,
+ align_corners=False,
+ )
+ blocks = [
+ SegResBlock(spatial_dims=spatial_dims, in_channels=filters, kernel_size=kernel_size, norm=norm, act=act)
+ for _ in range(blocks_up[i])
+ ]
+ level["blocks"] = nn.Sequential(*blocks)
+
+ if len(blocks_up) - i <= dsdepth: # deep supervision heads
+ level["head"] = Conv[Conv.CONV, spatial_dims](
+ in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True
+ )
+ else:
+ level["head"] = nn.Identity()
+
+ self.up_layers.append(level)
+
+ if n_up == 0: # in a corner case of flat structure (no downsampling), attache a single head
+ level = nn.ModuleDict(
+ {
+ "upsample": nn.Identity(),
+ "blocks": nn.Identity(),
+ "head": Conv[Conv.CONV, spatial_dims](
+ in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True
+ ),
+ }
+ )
+ self.up_layers.append(level)
+
+ def shape_factor(self):
+ """
+ Calculate the factors (divisors) that the input image shape must be divisible by
+ """
+ if self.anisotropic_scales is None:
+ d = [2 ** (len(self.blocks_down) - 1)] * self.spatial_dims
+ else:
+ d = list(np.prod(np.array(self.anisotropic_scales[:-1]), axis=0))
+ return d
+
+ def is_valid_shape(self, x):
+ """
+ Calculate if the input shape is divisible by the minimum factors for the current nework configuration
+ """
+ a = [i % j == 0 for i, j in zip(x.shape[2:], self.shape_factor())]
+ return all(a)
+
+ def _forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
+
+ if self.preprocess is not None:
+ x = self.preprocess(x)
+
+ if not self.is_valid_shape(x):
+ raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}")
+
+ x_down = self.encoder(x)
+
+ x_down.reverse()
+ x = x_down.pop(0)
+
+ if len(x_down) == 0:
+ x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]
+
+ outputs: List[torch.Tensor] = []
+
+ i = 0
+ for level in self.up_layers:
+ x = level["upsample"](x)
+ x = x + x_down[i]
+ x = level["blocks"](x)
+
+ if len(self.up_layers) - i <= self.dsdepth:
+ outputs.append(level["head"](x))
+ i = i + 1
+
+ outputs.reverse()
+
+ # in eval() mode, always return a single final output
+ if not self.training or len(outputs) == 1:
+ return outputs[0]
+
+ # return a list of DS outputs
+ return outputs
+
+ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
+ return self._forward(x)
diff --git a/monai/networks/nets/senet.py b/monai/networks/nets/senet.py
index 8933cbe7e95..b4c024b1f2e 100644
--- a/monai/networks/nets/senet.py
+++ b/monai/networks/nets/senet.py
@@ -34,7 +34,6 @@
"SE_NET_MODELS",
]
-
SE_NET_MODELS = {
"senet154": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth",
"se_resnet50": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth",
diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py
index 5391e725daa..933ed06f7fa 100644
--- a/monai/networks/nets/swin_unetr.py
+++ b/monai/networks/nets/swin_unetr.py
@@ -503,7 +503,7 @@ def forward(self, x, mask):
else:
attn = self.softmax(attn)
- attn = self.attn_drop(attn)
+ attn = self.attn_drop(attn).to(v.dtype)
x = (attn @ v).transpose(1, 2).reshape(b, n, c)
x = self.proj(x)
x = self.proj_drop(x)
diff --git a/monai/networks/nets/torchvision_fc.py b/monai/networks/nets/torchvision_fc.py
index ddfee6e0416..85103a2c04f 100644
--- a/monai/networks/nets/torchvision_fc.py
+++ b/monai/networks/nets/torchvision_fc.py
@@ -12,37 +12,89 @@
from typing import Any, Dict, Optional, Tuple
from monai.networks.nets import NetAdapter
-from monai.utils import deprecated_arg, optional_import
+from monai.utils import optional_import
models, _ = optional_import("torchvision.models")
-
__all__ = ["TorchVisionFCModel"]
class TorchVisionFCModel(NetAdapter):
"""
- Customize the fully connected layer of TorchVision model or replace it by convolutional layer.
+ Customize the fully connected layer of (pretrained) TorchVision model or replace it by convolutional layer.
+
+ This class supports two primary use cases:
+
+ - use ``pool=None`` to indicate no modification in the pooling layers. It should be used with ``fc_name``
+ to locate the target FC layer to modify:
+ In this case, the class will load a torchvision classification model,
+ replace the last fully connected (FC) layer with a new FC layer with ``num_classes`` outputs,
+ example input arguments: ``use_conv=False, pool=None, fc_name="heads.head"``.
+ The ``heads.head`` specifies the target FC of the input model, could be found by ``model.named_modules()``,
+ for example::
+
+ from torchvision.models import vit_b_16
+ print([name[0] for name in vit_b_16().named_modules()])
+
+ - use ``pool=""`` or set it to a tuple of pooling parameters to indicate modifications of both
+ the pooling and the FC layer. It should be used with ``node_name`` to locate the model feature outputs:
+ In this case, the class will load a torchvision model, remove the existing pooling and FC layers, and
+
+ - append an additional convolution layer:
+ ``use_conv=True, pool="", node_name="permute"``
+ - append an additional pooling and FC layers:
+ ``use_conv=False, pool=("avg", {"kernel_size": 7, "stride": 1}), node_name="permute"``
+ - append an additional pooling and convolution layers:
+ ``use_conv=True, pool=("avg", {"kernel_size": 7, "stride": 1}), node_name="permute"``
+
+ The ``permute`` in the example is the target feature extraction node of the input
+ `model_name`, could be found by using the torchvision feature extraction utilities, for example::
+
+ from torchvision.models.feature_extraction import get_graph_node_names
+ from torchvision.models import swin_t
+ print(get_graph_node_names(swin_t())[0])
Args:
model_name: name of any torchvision model with fully connected layer at the end.
``resnet18`` (default), ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``,
- ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``.
+ ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``, ``inception_v3``.
model details: https://pytorch.org/vision/stable/models.html.
num_classes: number of classes for the last classification layer. Default to 1.
dim: number of supported spatial dimensions in the specified model, depends on the model implementation.
default to 2 as most Torchvision models are for 2D image processing.
in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
- use_conv: whether use convolutional layer to replace the last layer, default to False.
- pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer,
- the second item is dictionary of the initialization args. if None, will not replace the `layers[-2]`.
- default to `("avg", {"kernel_size": 7, "stride": 1})`.
+ use_conv: whether to use convolutional layer to replace the last layer, default to False.
+ pool: parameters for the pooling layer, when it's a tuple, the first item is name of the pooling layer,
+ the second item is dictionary of the initialization args. If None, will not replace the `layers[-2]`.
+ default to `("avg", {"kernel_size": 7, "stride": 1})`. ``""`` indicates not adding a pooling layer.
bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,
default to True.
pretrained: whether to use the imagenet pretrained weights. Default to False.
+ fc_name: the corresponding layer attribute of the last fully connected layer. Defaults to ``"fc"``.
+ node_name: the corresponding feature extractor node name of `model`. Defaults to "", not in use.
+ weights: additional weights enum for the torchvision model.
+ kwargs: additional parameters for the torchvision model.
+
+ Example::
+
+ import torch
+ from torchvision.models.inception import Inception_V3_Weights
+
+ from monai.networks.nets import TorchVisionFCModel
+
+ model = TorchVisionFCModel(
+ "inception_v3",
+ num_classes=4,
+ weights=Inception_V3_Weights.IMAGENET1K_V1,
+ use_conv=False,
+ pool=None,
+ )
+ # model = TorchVisionFCModel("vit_b_16", num_classes=4, pool=None, in_channels=768, fc_name="heads")
+ output = model.forward(torch.randn(2, 3, 299, 299))
+ print(output.shape) # torch.Size([2, 4])
+
"""
- @deprecated_arg("n_classes", since="0.6")
def __init__(
self,
model_name: str = "resnet18",
@@ -53,15 +105,15 @@ def __init__(
pool: Optional[Tuple[str, Dict[str, Any]]] = ("avg", {"kernel_size": 7, "stride": 1}),
bias: bool = True,
pretrained: bool = False,
- n_classes: Optional[int] = None,
+ fc_name: str = "fc",
+ node_name: str = "",
+ weights=None,
+ **kwargs,
):
- # in case the new num_classes is default but you still call deprecated n_classes
- if n_classes is not None and num_classes == 1:
- num_classes = n_classes
- model = getattr(models, model_name)(pretrained=pretrained)
- # check if the model is compatible, should have a FC layer at the end
- if not str(list(model.children())[-1]).startswith("Linear"):
- raise ValueError(f"Model ['{model_name}'] does not have a Linear layer at the end.")
+ if weights is not None:
+ model = getattr(models, model_name)(weights=weights, **kwargs)
+ else:
+ model = getattr(models, model_name)(pretrained=pretrained, **kwargs) # 'pretrained' deprecated 0.13
super().__init__(
model=model,
@@ -71,4 +123,6 @@ def __init__(
use_conv=use_conv,
pool=pool,
bias=bias,
+ fc_name=fc_name,
+ node_name=node_name,
)
diff --git a/monai/networks/nets/transchex.py b/monai/networks/nets/transchex.py
index b03ff5a17d2..e16858368af 100644
--- a/monai/networks/nets/transchex.py
+++ b/monai/networks/nets/transchex.py
@@ -72,7 +72,26 @@ def from_pretrained(
else:
tempdir = tempfile.mkdtemp()
with tarfile.open(resolved_archive_file, "r:gz") as archive:
- archive.extractall(tempdir)
+
+ def is_within_directory(directory, target):
+
+ abs_directory = os.path.abspath(directory)
+ abs_target = os.path.abspath(target)
+
+ prefix = os.path.commonprefix([abs_directory, abs_target])
+
+ return prefix == abs_directory
+
+ def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
+
+ for member in tar.getmembers():
+ member_path = os.path.join(path, member.name)
+ if not is_within_directory(path, member_path):
+ raise Exception("Attempted Path Traversal in Tar File")
+
+ tar.extractall(path, members, numeric_owner=numeric_owner)
+
+ safe_extract(archive, tempdir)
serialization_dir = tempdir
model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
if state_dict is None and not from_tf:
diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index faccddee455..6db41ef8fbd 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -197,7 +197,7 @@ def _create_block(
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
"""
Returns the block object defining a layer of the UNet structure including the implementation of the skip
- between encoding (down) and and decoding (up) sides of the network.
+ between encoding (down) and decoding (up) sides of the network.
Args:
down_path: encoding half of the layer
diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py
index c53936d27fe..f95428693a1 100644
--- a/monai/networks/nets/unetr.py
+++ b/monai/networks/nets/unetr.py
@@ -40,6 +40,7 @@ def __init__(
res_block: bool = True,
dropout_rate: float = 0.0,
spatial_dims: int = 3,
+ qkv_bias: bool = False,
) -> None:
"""
Args:
@@ -56,6 +57,7 @@ def __init__(
res_block: bool argument to determine if residual block is used.
dropout_rate: faction of the input units to drop.
spatial_dims: number of spatial dims.
+ qkv_bias: apply the bias term for the qkv linear layer in self attention block
Examples::
@@ -96,6 +98,7 @@ def __init__(
classification=self.classification,
dropout_rate=dropout_rate,
spatial_dims=spatial_dims,
+ qkv_bias=qkv_bias,
)
self.encoder1 = UnetrBasicBlock(
spatial_dims=spatial_dims,
diff --git a/monai/networks/nets/varautoencoder.py b/monai/networks/nets/varautoencoder.py
index 7386883124e..31c2a5cfe62 100644
--- a/monai/networks/nets/varautoencoder.py
+++ b/monai/networks/nets/varautoencoder.py
@@ -19,7 +19,6 @@
from monai.networks.layers.convutils import calculate_out_shape, same_padding
from monai.networks.layers.factories import Act, Norm
from monai.networks.nets import AutoEncoder
-from monai.utils import deprecated_arg
__all__ = ["VarAutoEncoder"]
@@ -49,9 +48,7 @@ class VarAutoEncoder(AutoEncoder):
bias: whether to have a bias term in convolution blocks. Defaults to True.
According to `Performance Tuning Guide `_,
if a conv layer is directly followed by a batch norm layer, bias should be False.
-
- .. deprecated:: 0.6.0
- ``dimensions`` is deprecated, use ``spatial_dims`` instead.
+ use_sigmoid: whether to use the sigmoid function on final output. Defaults to True.
Examples::
@@ -72,9 +69,6 @@ class VarAutoEncoder(AutoEncoder):
https://github.com/Project-MONAI/tutorials/blob/master/modules/varautoencoder_mednist.ipynb
"""
- @deprecated_arg(
- name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
- )
def __init__(
self,
spatial_dims: int,
@@ -93,15 +87,14 @@ def __init__(
norm: Union[Tuple, str] = Norm.INSTANCE,
dropout: Optional[Union[Tuple, str, float]] = None,
bias: bool = True,
- dimensions: Optional[int] = None,
+ use_sigmoid: bool = True,
) -> None:
self.in_channels, *self.in_shape = in_shape
+ self.use_sigmoid = use_sigmoid
self.latent_size = latent_size
self.final_size = np.asarray(self.in_shape, dtype=int)
- if dimensions is not None:
- spatial_dims = dimensions
super().__init__(
spatial_dims,
@@ -158,4 +151,4 @@ def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
mu, logvar = self.encode_forward(x)
z = self.reparameterize(mu, logvar)
- return self.decode_forward(z), mu, logvar, z
+ return self.decode_forward(z, self.use_sigmoid), mu, logvar, z
diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py
index d41ba447e92..e4166c78b6c 100644
--- a/monai/networks/nets/vit.py
+++ b/monai/networks/nets/vit.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Sequence, Union
import torch
@@ -44,6 +43,7 @@ def __init__(
dropout_rate: float = 0.0,
spatial_dims: int = 3,
post_activation="Tanh",
+ qkv_bias: bool = False,
) -> None:
"""
Args:
@@ -61,6 +61,7 @@ def __init__(
spatial_dims: number of spatial dimensions.
post_activation: add a final acivation function to the classification head when `classification` is True.
Default to "Tanh" for `nn.Tanh()`. Set to other values to remove this function.
+ qkv_bias: apply bias to the qkv linear layer in self attention block
Examples::
@@ -95,7 +96,7 @@ def __init__(
spatial_dims=spatial_dims,
)
self.blocks = nn.ModuleList(
- [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)]
+ [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias) for i in range(num_layers)]
)
self.norm = nn.LayerNorm(hidden_size)
if self.classification:
diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py
index 9e5490f9d66..6197f6bd992 100644
--- a/monai/networks/nets/vitautoenc.py
+++ b/monai/networks/nets/vitautoenc.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Sequence, Union
import torch
diff --git a/monai/networks/nets/vnet.py b/monai/networks/nets/vnet.py
index 7669b4678ee..9abd2bc5e20 100644
--- a/monai/networks/nets/vnet.py
+++ b/monai/networks/nets/vnet.py
@@ -67,16 +67,19 @@ def __init__(
):
super().__init__()
- if 16 % in_channels != 0:
- raise ValueError(f"16 should be divisible by in_channels, got in_channels={in_channels}.")
+ if out_channels % in_channels != 0:
+ raise ValueError(
+ f"out channels should be divisible by in_channels. Got in_channels={in_channels}, out_channels={out_channels}."
+ )
self.spatial_dims = spatial_dims
self.in_channels = in_channels
- self.act_function = get_acti_layer(act, 16)
+ self.out_channels = out_channels
+ self.act_function = get_acti_layer(act, out_channels)
self.conv_block = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
- out_channels=16,
+ out_channels=out_channels,
kernel_size=5,
act=None,
norm=Norm.BATCH,
@@ -85,7 +88,7 @@ def __init__(
def forward(self, x):
out = self.conv_block(x)
- repeat_num = 16 // self.in_channels
+ repeat_num = self.out_channels // self.in_channels
x16 = x.repeat([1, repeat_num, 1, 1, 1][: self.spatial_dims + 2])
out = self.act_function(torch.add(out, x16))
return out
diff --git a/monai/networks/utils.py b/monai/networks/utils.py
index f9018b0c391..b353168a8ab 100644
--- a/monai/networks/utils.py
+++ b/monai/networks/utils.py
@@ -21,9 +21,12 @@
import torch
import torch.nn as nn
+from monai.apps.utils import get_logger
from monai.config import PathLike
-from monai.utils.deprecate_utils import deprecated, deprecated_arg
+from monai.utils.deprecate_utils import deprecated
from monai.utils.misc import ensure_tuple, save_obj, set_determinism
+from monai.utils.module import look_up_option, pytorch_after
+from monai.utils.type_conversion import convert_to_tensor
__all__ = [
"one_hot",
@@ -41,10 +44,69 @@
"save_state",
"convert_to_torchscript",
"meshgrid_ij",
+ "meshgrid_xy",
"replace_modules",
"replace_modules_temp",
+ "look_up_named_module",
+ "set_named_module",
]
+logger = get_logger(module_name=__name__)
+
+
+def look_up_named_module(name: str, mod, print_all_options=False):
+ """
+ get the named module in `mod` by the attribute name,
+ for example ``look_up_named_module(net, "features.3.1.attn")``
+
+ Args:
+ name: a string representing the module attribute.
+ mod: a pytorch module to be searched (in ``mod.named_modules()``).
+ print_all_options: whether to print all named modules when `name` is not found in `mod`. Defaults to False.
+
+ Returns:
+ the corresponding pytorch module's subcomponent such as ``net.features[3][1].attn``
+ """
+ name_str = look_up_option(
+ name, {n[0] for n in mod.named_modules()}, default=None, print_all_options=print_all_options
+ )
+ if name_str is None:
+ return None
+ if name_str == "":
+ return mod
+ for n in name_str.split("."):
+ if n.isdigit():
+ mod = mod[int(n)]
+ else:
+ n = look_up_option(n, {item[0] for item in mod.named_modules()}, default=None, print_all_options=False)
+ if n is None:
+ return None
+ mod = getattr(mod, n)
+ return mod
+
+
+def set_named_module(mod, name: str, new_layer):
+ """
+ look up `name` in `mod` and replace the layer with `new_layer`, return the updated `mod`.
+
+ Args:
+ mod: a pytorch module to be updated.
+ name: a string representing the target module attribute.
+ new_layer: a new module replacing the corresponding layer at ``mod.name``.
+
+ Returns:
+ an updated ``mod``
+
+ See also: :py:func:`monai.networks.utils.look_up_named_module`.
+ """
+ mods_attr = name.rsplit(".", 1)
+ submods, attr = mods_attr if len(mods_attr) == 2 else ("", name)
+ if not attr:
+ return new_layer
+ _mod = look_up_named_module(submods, mod)
+ setattr(_mod, attr, new_layer)
+ return mod
+
def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
"""
@@ -133,7 +195,7 @@ def predict_segmentation(logits: torch.Tensor, mutually_exclusive: bool = False,
def normalize_transform(
- shape: Sequence[int],
+ shape,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
align_corners: bool = False,
@@ -150,7 +212,7 @@ def normalize_transform(
- `align_corners=True`, `zero_centered=True`, normalizing from ``[-(d-1)/2, (d-1)/2]``.
Args:
- shape: input spatial shape
+ shape: input spatial shape, a sequence of integers.
device: device on which the returned affine will be allocated.
dtype: data type of the returned affine
align_corners: if True, consider -1 and 1 to refer to the centers of the
@@ -159,7 +221,8 @@ def normalize_transform(
zero_centered: whether the coordinates are normalized from a zero-centered range, default to `False`.
Setting this flag and `align_corners` will jointly specify the normalization source range.
"""
- norm = torch.tensor(shape, dtype=torch.float64, device=device) # no in-place change
+ shape = convert_to_tensor(shape, torch.float64, device=device, wrap_sequence=True, track_meta=False)
+ norm = shape.clone().detach().to(dtype=torch.float64, device=device) # no in-place change
if align_corners:
norm[norm <= 1.0] = 2.0
norm = 2.0 / (norm - 1.0)
@@ -170,10 +233,10 @@ def normalize_transform(
norm[norm <= 0.0] = 2.0
norm = 2.0 / norm
norm = torch.diag(torch.cat((norm, torch.ones((1,), dtype=torch.float64, device=device))))
- norm[:-1, -1] = 1.0 / torch.tensor(shape, dtype=torch.float64, device=device) - (0.0 if zero_centered else 1.0)
+ norm[:-1, -1] = 1.0 / shape - (0.0 if zero_centered else 1.0)
norm = norm.unsqueeze(0).to(dtype=dtype)
norm.requires_grad = False
- return norm
+ return norm # type: ignore
def to_norm_affine(
@@ -257,12 +320,7 @@ def icnr_init(conv, upsample_factor, init=nn.init.kaiming_normal_):
conv.weight.data.copy_(kernel)
-@deprecated_arg(
- name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
-)
-def pixelshuffle(
- x: torch.Tensor, spatial_dims: int, scale_factor: int, dimensions: Optional[int] = None
-) -> torch.Tensor:
+def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
"""
Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.
@@ -276,17 +334,12 @@ def pixelshuffle(
spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
scale_factor: factor to rescale the spatial dimensions by, must be >=1
- .. deprecated:: 0.6.0
- ``dimensions`` is deprecated, use ``spatial_dims`` instead.
-
Returns:
Reshuffled version of `x`.
Raises:
ValueError: When input channels of `x` are not divisible by (scale_factor ** spatial_dims)
"""
- if dimensions is not None:
- spatial_dims = dimensions
dim, factor = spatial_dims, scale_factor
input_size = list(x.size())
batch_size, channels = input_size[:2]
@@ -461,8 +514,10 @@ def copy_model_state(
updated_keys = sorted(set(updated_keys))
unchanged_keys = sorted(set(all_keys).difference(updated_keys))
- print(f"'dst' model updated: {len(updated_keys)} of {len(dst_dict)} variables.")
+ logger.info(f"'dst' model updated: {len(updated_keys)} of {len(dst_dict)} variables.")
if inplace and isinstance(dst, torch.nn.Module):
+ if isinstance(dst, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+ dst = dst.module
dst.load_state_dict(dst_dict)
return dst_dict, updated_keys, unchanged_keys
@@ -559,7 +614,8 @@ def convert_to_torchscript(
# compare TorchScript and PyTorch results
for r1, r2 in zip(torch_out, torchscript_out):
if isinstance(r1, torch.Tensor) or isinstance(r2, torch.Tensor):
- torch.testing.assert_allclose(r1, r2, rtol=rtol, atol=atol)
+ assert_fn = torch.testing.assert_close if pytorch_after(1, 11) else torch.testing.assert_allclose
+ assert_fn(r1, r2, rtol=rtol, atol=atol)
return script_module
@@ -567,9 +623,17 @@ def convert_to_torchscript(
def meshgrid_ij(*tensors):
if torch.meshgrid.__kwdefaults__ is not None and "indexing" in torch.meshgrid.__kwdefaults__:
return torch.meshgrid(*tensors, indexing="ij") # new api pytorch after 1.10
+
return torch.meshgrid(*tensors)
+def meshgrid_xy(*tensors):
+ if torch.meshgrid.__kwdefaults__ is not None and "indexing" in torch.meshgrid.__kwdefaults__:
+ return torch.meshgrid(*tensors, indexing="xy") # new api pytorch after 1.10
+
+ return torch.meshgrid(tensors[1], tensors[0], *tensors[2:])
+
+
def _replace_modules(
parent: torch.nn.Module,
name: str,
diff --git a/monai/optimizers/lr_scheduler.py b/monai/optimizers/lr_scheduler.py
index 83412c61ea6..cb047f8bc5e 100644
--- a/monai/optimizers/lr_scheduler.py
+++ b/monai/optimizers/lr_scheduler.py
@@ -62,7 +62,13 @@ class WarmupCosineSchedule(LambdaLR):
"""
def __init__(
- self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1
+ self,
+ optimizer: Optimizer,
+ warmup_steps: int,
+ t_total: int,
+ cycles: float = 0.5,
+ last_epoch: int = -1,
+ warmup_multiplier: float = 0,
) -> None:
"""
Args:
@@ -71,16 +77,22 @@ def __init__(
t_total: total number of training iterations.
cycles: cosine cycles parameter.
last_epoch: the index of last epoch.
+ warmup_multiplier: if provided, starts the linear warmup from this fraction of the intial lr.
+ Must be in 0..1 interval. Defaults to 0
Returns:
None
"""
- self.warmup_steps = warmup_steps
+ self.warmup_steps = min(max(warmup_steps, 0), t_total)
+ self.warmup_multiplier = warmup_multiplier
self.t_total = t_total
self.cycles = cycles
+ if warmup_multiplier < 0 or warmup_multiplier > 1:
+ raise ValueError("warmup_multiplier must be in 0..1 range")
super().__init__(optimizer, self.lr_lambda, last_epoch)
def lr_lambda(self, step):
if step < self.warmup_steps:
- return float(step) / float(max(1.0, self.warmup_steps))
+ f = float(step) / float(max(1.0, self.warmup_steps))
+ return self.warmup_multiplier + (1 - self.warmup_multiplier) * f
progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index bcb2f566705..9cabc167a7f 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -89,6 +89,7 @@
)
from .intensity.array import (
AdjustContrast,
+ ComputeHoVerMaps,
DetectEnvelope,
ForegroundMask,
GaussianSharpen,
@@ -98,6 +99,7 @@
IntensityRemap,
KSpaceSpikeNoise,
MaskIntensity,
+ MedianSmooth,
NormalizeIntensity,
RandAdjustContrast,
RandBiasField,
@@ -127,6 +129,9 @@
AdjustContrastd,
AdjustContrastD,
AdjustContrastDict,
+ ComputeHoVerMapsd,
+ ComputeHoVerMapsD,
+ ComputeHoVerMapsDict,
ForegroundMaskd,
ForegroundMaskD,
ForegroundMaskDict,
@@ -148,6 +153,9 @@
MaskIntensityd,
MaskIntensityD,
MaskIntensityDict,
+ MedianSmoothd,
+ MedianSmoothD,
+ MedianSmoothDict,
NormalizeIntensityd,
NormalizeIntensityD,
NormalizeIntensityDict,
@@ -263,6 +271,8 @@
LabelToContour,
MeanEnsemble,
ProbNMS,
+ RemoveSmallObjects,
+ SobelGradients,
VoteEnsemble,
)
from .post.dictionary import (
@@ -296,13 +306,32 @@
ProbNMSD,
ProbNMSd,
ProbNMSDict,
+ RemoveSmallObjectsD,
+ RemoveSmallObjectsd,
+ RemoveSmallObjectsDict,
SaveClassificationD,
SaveClassificationd,
SaveClassificationDict,
+ SobelGradientsd,
+ SobelGradientsD,
+ SobelGradientsDict,
VoteEnsembleD,
VoteEnsembled,
VoteEnsembleDict,
)
+from .signal.array import (
+ SignalContinuousWavelet,
+ SignalFillEmpty,
+ SignalRandAddGaussianNoise,
+ SignalRandAddSine,
+ SignalRandAddSinePartial,
+ SignalRandAddSquarePulse,
+ SignalRandAddSquarePulsePartial,
+ SignalRandDrop,
+ SignalRandScale,
+ SignalRandShift,
+ SignalRemoveFrequency,
+)
from .smooth_field.array import (
RandSmoothDeform,
RandSmoothFieldAdjustContrast,
@@ -420,7 +449,18 @@
ZoomD,
ZoomDict,
)
-from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform
+from .transform import (
+ LazyTrait,
+ LazyTransform,
+ MapTransform,
+ MultiSampleTrait,
+ Randomizable,
+ RandomizableTrait,
+ RandomizableTransform,
+ ThreadUnsafe,
+ Transform,
+ apply_transform,
+)
from .utility.array import (
AddChannel,
AddCoordinateChannels,
@@ -605,6 +645,7 @@
map_spatial_axes,
print_transform_backends,
rand_choice,
+ remove_small_objects,
rescale_array,
rescale_array_int_max,
rescale_instance_array,
diff --git a/monai/transforms/adaptors.py b/monai/transforms/adaptors.py
index 92fd11cf79d..1edbcc63e29 100644
--- a/monai/transforms/adaptors.py
+++ b/monai/transforms/adaptors.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
How to use the adaptor function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 1d60c34c3e4..7b55a993a19 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -17,6 +17,7 @@
import numpy as np
+import monai
from monai.transforms.inverse import InvertibleTransform
# For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
@@ -254,26 +255,27 @@ def __call__(self, data):
_transform = self.transforms[index]
data = apply_transform(_transform, data, self.map_items, self.unpack_items, self.log_stats)
# if the data is a mapping (dictionary), append the OneOf transform to the end
- if isinstance(data, Mapping):
- for key in data.keys():
- if self.trace_key(key) in data:
+ if isinstance(data, monai.data.MetaTensor):
+ self.push_transform(data, extra_info={"index": index})
+ elif isinstance(data, Mapping):
+ for key in data: # dictionary not change size during iteration
+ if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
self.push_transform(data, key, extra_info={"index": index})
return data
def inverse(self, data):
if len(self.transforms) == 0:
return data
- if not isinstance(data, Mapping):
- raise RuntimeError("Inverse only implemented for Mapping (dictionary) data")
- # loop until we get an index and then break (since they'll all be the same)
index = None
- for key in data.keys():
- if self.trace_key(key) in data:
- # get the index of the applied OneOf transform
- index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
- # and then remove the OneOf transform
- self.pop_transform(data, key)
+ if isinstance(data, monai.data.MetaTensor):
+ index = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["index"]
+ elif isinstance(data, Mapping):
+ for key in data:
+ if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
+ index = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
+ else:
+ raise RuntimeError("Inverse only implemented for Mapping (dictionary) or MetaTensor data.")
if index is None:
# no invertible transforms have been applied
return data
diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index bdf9d07f24b..9a773c43693 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -198,7 +198,7 @@ def update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]):
spatial_rank = max(len(tensor.affine) - 1, 1)
to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad
mat = create_translate(spatial_rank, to_shift)
- tensor.meta["affine"] = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0]
+ tensor.affine = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0]
def inverse(self, data: MetaTensor) -> MetaTensor:
transform = self.pop_transform(data)
@@ -363,7 +363,7 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int
class Crop(InvertibleTransform):
"""
- Perform crop operation on the input image.
+ Perform crop operations on the input image.
"""
@@ -378,7 +378,7 @@ def compute_slices(
roi_slices: Optional[Sequence[slice]] = None,
):
"""
- Compute the crop slices based on specified `center & size` or `start & end`.
+ Compute the crop slices based on specified `center & size` or `start & end` or `slices`.
Args:
roi_center: voxel coordinates for center of the crop ROI.
@@ -450,7 +450,7 @@ def update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...]):
spatial_rank = max(len(tensor.affine) - 1, 1)
to_shift = [s.start if s.start is not None else 0 for s in ensure_tuple(slices)[1:]]
mat = create_translate(spatial_rank, to_shift)
- tensor.meta["affine"] = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0]
+ tensor.affine = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0]
def inverse(self, img: MetaTensor) -> MetaTensor:
transform = self.pop_transform(img)
@@ -471,7 +471,7 @@ class SpatialCrop(Crop):
It can support to crop ND spatial (channel-first) data.
The cropped region can be parameterised in various ways:
- - a list of slices for each spatial dimension (allows for use of -ve indexing and `None`)
+ - a list of slices for each spatial dimension (allows for use of negative indexing and `None`)
- a spatial center and size
- the start and end coordinates of the ROI
"""
diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index 3fa3aa63fbc..16d454df103 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -26,11 +26,10 @@
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.meta_obj import get_track_meta
from monai.data.utils import get_random_patch, get_valid_patch_size
-from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
+from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
-from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import TransformBackends
from monai.utils.misc import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple
from monai.utils.module import min_version, optional_import
@@ -57,6 +56,7 @@
"MaskIntensity",
"DetectEnvelope",
"SavitzkyGolaySmooth",
+ "MedianSmooth",
"GaussianSmooth",
"RandGaussianSmooth",
"GaussianSharpen",
@@ -73,6 +73,7 @@
"IntensityRemap",
"RandIntensityRemap",
"ForegroundMask",
+ "ComputeHoVerMaps",
]
@@ -189,14 +190,13 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
"""
Apply the transform to `img`.
"""
- img = convert_to_tensor(img, track_meta=get_track_meta())
+ img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)
if randomize:
super().randomize(None)
if not self._do_transform:
return img
- img, *_ = convert_data_type(img, dtype=self.dtype)
if self.channel_wise:
_mean = ensure_tuple_rep(self.mean, len(img))
_std = ensure_tuple_rep(self.std, len(img))
@@ -335,9 +335,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
- img = convert_to_tensor(img, track_meta=get_track_meta())
- if self.dtype is not None:
- img, *_ = convert_data_type(img, dtype=self.dtype)
+ img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)
if self.channel_wise:
for i, d in enumerate(img):
img[i] = self._stdshift(d) # type: ignore
@@ -394,7 +392,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
"""
Apply the transform to `img`.
"""
- img = convert_to_tensor(img, track_meta=get_track_meta())
+ img = convert_to_tensor(img, track_meta=get_track_meta(), dtype=self.dtype)
if randomize:
self.randomize()
@@ -506,7 +504,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
self.randomize()
if not self._do_transform:
- return img
+ return convert_data_type(img, dtype=self.dtype)[0]
return ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)(img)
@@ -1139,6 +1137,35 @@ def __call__(self, img: NdarrayOrTensor):
return out
+class MedianSmooth(Transform):
+ """
+ Apply median filter to the input data based on specified `radius` parameter.
+ A default value `radius=1` is provided for reference.
+
+ See also: :py:func:`monai.networks.layers.median_filter`
+
+ Args:
+ radius: if a list of values, must match the count of spatial dimensions of input data,
+ and apply every value in the list to 1 spatial dimension. if only 1 value provided,
+ use it for all spatial dimensions.
+ """
+
+ backend = [TransformBackends.TORCH]
+
+ def __init__(self, radius: Union[Sequence[int], int] = 1) -> None:
+ self.radius = radius
+
+ def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
+ img = convert_to_tensor(img, track_meta=get_track_meta())
+ img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
+ spatial_dims = img_t.ndim - 1
+ r = ensure_tuple_rep(self.radius, spatial_dims)
+ median_filter_instance = MedianFilter(r, spatial_dims=spatial_dims)
+ out_t: torch.Tensor = median_filter_instance(img_t)
+ out, *_ = convert_to_dst_type(out_t, dst=img, dtype=out_t.dtype)
+ return out
+
+
class GaussianSmooth(Transform):
"""
Apply Gaussian smooth to the input data based on specified `sigma` parameter.
@@ -1467,8 +1494,7 @@ class GibbsNoise(Transform, Fourier):
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
- @deprecated_arg(name="as_tensor_output", since="0.6")
- def __init__(self, alpha: float = 0.1, as_tensor_output: bool = True) -> None:
+ def __init__(self, alpha: float = 0.1) -> None:
if alpha > 1 or alpha < 0:
raise ValueError("alpha must take values in the interval [0, 1].")
@@ -1546,8 +1572,7 @@ class RandGibbsNoise(RandomizableTransform):
backend = GibbsNoise.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
- def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0), as_tensor_output: bool = True) -> None:
+ def __init__(self, prob: float = 0.1, alpha: Sequence[float] = (0.0, 1.0)) -> None:
if len(alpha) != 2:
raise ValueError("alpha length must be 2.")
if alpha[1] > 1 or alpha[0] < 0:
@@ -1619,13 +1644,7 @@ class KSpaceSpikeNoise(Transform, Fourier):
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
- @deprecated_arg(name="as_tensor_output", since="0.6")
- def __init__(
- self,
- loc: Union[Tuple, Sequence[Tuple]],
- k_intensity: Optional[Union[Sequence[float], float]] = None,
- as_tensor_output: bool = True,
- ):
+ def __init__(self, loc: Union[Tuple, Sequence[Tuple]], k_intensity: Optional[Union[Sequence[float], float]] = None):
self.loc = ensure_tuple(loc)
self.k_intensity = k_intensity
@@ -1754,13 +1773,11 @@ class RandKSpaceSpikeNoise(RandomizableTransform, Fourier):
backend = KSpaceSpikeNoise.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
prob: float = 0.1,
intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None,
channel_wise: bool = True,
- as_tensor_output: bool = True,
):
self.intensity_range = intensity_range
@@ -2301,3 +2318,44 @@ def __call__(self, image: NdarrayOrTensor):
mask = np.stack(foregrounds).all(axis=0)
return convert_to_dst_type(src=mask, dst=image)[0]
+
+
+class ComputeHoVerMaps(Transform):
+ """Compute horizontal and vertical maps from an instance mask
+ It generates normalized horizontal and vertical distances to the center of mass of each region.
+ Input data with the size of [1xHxW[xD]], which channel dim will temporarily removed for calculating coordinates.
+
+ Args:
+ dtype: the data type of output Tensor. Defaults to `"float32"`.
+
+ Return:
+ A torch.Tensor with the size of [2xHxW[xD]], which is stack horizontal and vertical maps
+
+ """
+
+ def __init__(self, dtype: DtypeLike = "float32") -> None:
+ super().__init__()
+ self.dtype = dtype
+
+ def __call__(self, mask: NdarrayOrTensor):
+ instance_mask = convert_data_type(mask, np.ndarray)[0]
+
+ h_map = instance_mask.astype(self.dtype, copy=True)
+ v_map = instance_mask.astype(self.dtype, copy=True)
+ instance_mask = instance_mask.squeeze(0) # remove channel dim
+
+ for region in skimage.measure.regionprops(instance_mask):
+ v_dist = region.coords[:, 0] - region.centroid[0]
+ h_dist = region.coords[:, 1] - region.centroid[1]
+
+ h_dist[h_dist < 0] /= -np.amin(h_dist)
+ h_dist[h_dist > 0] /= np.amax(h_dist)
+
+ v_dist[v_dist < 0] /= -np.amin(v_dist)
+ v_dist[v_dist > 0] /= np.amax(v_dist)
+
+ h_map[h_map == region.label] = h_dist
+ v_map[v_map == region.label] = v_dist
+
+ hv_maps = convert_to_tensor(np.concatenate([h_map, v_map]), track_meta=get_track_meta())
+ return hv_maps
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py
index b9308255cfd..efdfd502ca9 100644
--- a/monai/transforms/intensity/dictionary.py
+++ b/monai/transforms/intensity/dictionary.py
@@ -15,7 +15,7 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""
-from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
@@ -24,6 +24,7 @@
from monai.data.meta_obj import get_track_meta
from monai.transforms.intensity.array import (
AdjustContrast,
+ ComputeHoVerMaps,
ForegroundMask,
GaussianSharpen,
GaussianSmooth,
@@ -31,6 +32,7 @@
HistogramNormalize,
KSpaceSpikeNoise,
MaskIntensity,
+ MedianSmooth,
NormalizeIntensity,
RandAdjustContrast,
RandBiasField,
@@ -57,7 +59,6 @@
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.transforms.utils import is_positive
from monai.utils import convert_to_tensor, ensure_tuple, ensure_tuple_rep
-from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import PostFix
__all__ = [
@@ -78,6 +79,7 @@
"ScaleIntensityRangePercentilesd",
"MaskIntensityd",
"SavitzkyGolaySmoothd",
+ "MedianSmoothd",
"GaussianSmoothd",
"RandGaussianSmoothd",
"GaussianSharpend",
@@ -91,6 +93,7 @@
"RandCoarseShuffled",
"HistogramNormalized",
"ForegroundMaskd",
+ "ComputeHoVerMapsd",
"RandGaussianNoiseD",
"RandGaussianNoiseDict",
"ShiftIntensityD",
@@ -123,6 +126,8 @@
"MaskIntensityDict",
"SavitzkyGolaySmoothD",
"SavitzkyGolaySmoothDict",
+ "MedianSmoothD",
+ "MedianSmoothDict",
"GaussianSmoothD",
"GaussianSmoothDict",
"RandGaussianSmoothD",
@@ -151,6 +156,8 @@
"RandKSpaceSpikeNoiseDict",
"ForegroundMaskD",
"ForegroundMaskDict",
+ "ComputeHoVerMapsD",
+ "ComputeHoVerMapsDict",
]
DEFAULT_POST_FIX = PostFix.meta()
@@ -203,13 +210,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d
# all the keys share the same random noise
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d
- self.rand_gaussian_noise.randomize(d[first_key]) # type: ignore
+ self.rand_gaussian_noise.randomize(d[first_key])
+
for key in self.key_iterator(d):
d[key] = self.rand_gaussian_noise(img=d[key], randomize=False)
return d
@@ -241,7 +249,6 @@ class RandRicianNoised(RandomizableTransform, MapTransform):
backend = RandRicianNoise.backend
- @deprecated_arg("global_prob", since="0.7")
def __init__(
self,
keys: KeysCollection,
@@ -380,8 +387,8 @@ def __init__(
meta_key_postfix: if meta_keys is None, use `key_{postfix}` to fetch the metadata according
to the key data, default is `meta_dict`, the metadata is a dictionary object.
used to extract the factor value is `factor_key` is not None.
- prob: probability of rotating.
- (Default 0.1, with 10% probability it returns a rotated array.)
+ prob: probability of shift.
+ (Default 0.1, with 10% probability it returns an array shifted intensity.)
allow_missing_keys: don't raise exception if key is missing.
"""
MapTransform.__init__(self, keys, allow_missing_keys)
@@ -578,8 +585,8 @@ def __init__(
See also: :py:class:`monai.transforms.compose.MapTransform`
factors: factor range to randomly scale by ``v = v * (1 + factor)``.
if single number, factor value is picked from (-factors, factors).
- prob: probability of rotating.
- (Default 0.1, with 10% probability it returns a rotated array.)
+ prob: probability of scale.
+ (Default 0.1, with 10% probability it returns a scaled array.)
dtype: output data type, if None, same as input image. defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
@@ -659,13 +666,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d
# all the keys share the same random bias factor
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d
- self.rand_bias_field.randomize(img_size=d[first_key].shape[1:]) # type: ignore
+ self.rand_bias_field.randomize(img_size=d[first_key].shape[1:])
+
for key in self.key_iterator(d):
d[key] = self.rand_bias_field(d[key], randomize=False)
return d
@@ -984,6 +992,35 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d
+class MedianSmoothd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.transforms.MedianSmooth`.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ See also: :py:class:`monai.transforms.compose.MapTransform`
+ radius: if a list of values, must match the count of spatial dimensions of input data,
+ and apply every value in the list to 1 spatial dimension. if only 1 value provided,
+ use it for all spatial dimensions.
+ allow_missing_keys: don't raise exception if key is missing.
+
+ """
+
+ backend = MedianSmooth.backend
+
+ def __init__(
+ self, keys: KeysCollection, radius: Union[Sequence[int], int], allow_missing_keys: bool = False
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.converter = MedianSmooth(radius)
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ d[key] = self.converter(d[key])
+ return d
+
+
class GaussianSmoothd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.GaussianSmooth`.
@@ -1196,7 +1233,7 @@ def __call__(self, data: Dict[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Ndar
class RandHistogramShiftd(RandomizableTransform, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandHistogramShift`.
- Apply random nonlinear transform the the image's intensity histogram.
+ Apply random nonlinear transform the image's intensity histogram.
Args:
keys: keys of the corresponding items to be transformed.
@@ -1270,14 +1307,12 @@ class RandGibbsNoised(RandomizableTransform, MapTransform):
backend = RandGibbsNoise.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
keys: KeysCollection,
prob: float = 0.1,
alpha: Sequence[float] = (0.0, 1.0),
allow_missing_keys: bool = False,
- as_tensor_output: bool = True,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
@@ -1327,10 +1362,7 @@ class GibbsNoised(MapTransform):
backend = GibbsNoise.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
- def __init__(
- self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False, as_tensor_output: bool = True
- ) -> None:
+ def __init__(self, keys: KeysCollection, alpha: float = 0.5, allow_missing_keys: bool = False) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
self.transform = GibbsNoise(alpha)
@@ -1386,14 +1418,12 @@ class KSpaceSpikeNoised(MapTransform):
backend = KSpaceSpikeNoise.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
keys: KeysCollection,
loc: Union[Tuple, Sequence[Tuple]],
k_intensity: Optional[Union[Sequence[float], float]] = None,
allow_missing_keys: bool = False,
- as_tensor_output: bool = True,
) -> None:
super().__init__(keys, allow_missing_keys)
@@ -1450,21 +1480,13 @@ class RandKSpaceSpikeNoised(RandomizableTransform, MapTransform):
backend = RandKSpaceSpikeNoise.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
- @deprecated_arg(name="common_sampling", since="0.6")
- @deprecated_arg(name="common_seed", since="0.6")
- @deprecated_arg(name="global_prob", since="0.6")
def __init__(
self,
keys: KeysCollection,
- global_prob: float = 1.0,
prob: float = 0.1,
intensity_range: Optional[Sequence[Union[Sequence[float], float]]] = None,
channel_wise: bool = True,
- common_sampling: bool = False,
- common_seed: int = 42,
allow_missing_keys: bool = False,
- as_tensor_output: bool = True,
):
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob=prob)
@@ -1564,8 +1586,8 @@ def __call__(self, data):
return d
# expect all the specified keys have same spatial shape and share same random holes
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d
@@ -1637,8 +1659,8 @@ def __call__(self, data):
return d
# expect all the specified keys have same spatial shape and share same random holes
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d
@@ -1741,6 +1763,39 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d
+class ComputeHoVerMapsd(MapTransform):
+ """Compute horizontal and vertical maps from an instance mask
+ It generates normalized horizontal and vertical distances to the center of mass of each region.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ dtype: the type of output Tensor. Defaults to `"float32"`.
+ new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of
+ key intact. Defaults to '"_hover", so if the input key is "mask" the output will be "hover_mask".
+ allow_missing_keys: do not raise exception if key is missing.
+
+ """
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ dtype: DtypeLike = "float32",
+ new_key_prefix: str = "hover_",
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.transform = ComputeHoVerMaps(dtype=dtype)
+ self.new_key_prefix = new_key_prefix
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ new_key = key if self.new_key_prefix is None else self.new_key_prefix + key
+ d[new_key] = self.transform(d[key])
+
+ return d
+
+
RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised
RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised
ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd
@@ -1758,6 +1813,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
ScaleIntensityRangePercentilesD = ScaleIntensityRangePercentilesDict = ScaleIntensityRangePercentilesd
MaskIntensityD = MaskIntensityDict = MaskIntensityd
SavitzkyGolaySmoothD = SavitzkyGolaySmoothDict = SavitzkyGolaySmoothd
+MedianSmoothD = MedianSmoothDict = MedianSmoothd
GaussianSmoothD = GaussianSmoothDict = GaussianSmoothd
RandGaussianSmoothD = RandGaussianSmoothDict = RandGaussianSmoothd
GaussianSharpenD = GaussianSharpenDict = GaussianSharpend
@@ -1771,3 +1827,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized
RandCoarseShuffleD = RandCoarseShuffleDict = RandCoarseShuffled
ForegroundMaskD = ForegroundMaskDict = ForegroundMaskd
+ComputeHoVerMapsD = ComputeHoVerMapsDict = ComputeHoVerMapsd
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index dc43475b631..3cfb2b19534 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -99,6 +99,10 @@ class LoadImage(Transform):
- Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
(npz, npy -> NumpyReader), (nrrd -> NrrdReader), (DICOM file -> ITKReader).
+ Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after
+ loading the array because the `HW` definition for non-medical specific file formats is different
+ from other common medical packages.
+
See also:
- tutorial: https://github.com/Project-MONAI/tutorials/blob/master/modules/load_medical_images.ipynb
@@ -274,7 +278,7 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option
img = EnsureChannelFirst()(img)
if self.image_only:
return img
- return img, img.meta # for compatibility purpose
+ return img, img.meta if isinstance(img, MetaTensor) else meta_data
class SaveImage(Transform):
diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py
index 761c891e85b..a1f088b4980 100644
--- a/monai/transforms/io/dictionary.py
+++ b/monai/transforms/io/dictionary.py
@@ -51,6 +51,10 @@ class LoadImaged(MapTransform):
- Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader),
(npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader).
+ Please note that for png, jpg, bmp, and other 2D formats, readers often swap axis 0 and 1 after
+ loading the array because the `HW` definition for non-medical specific file formats is different
+ from other common medical packages.
+
Note:
- If `reader` is specified, the loader will attempt to use the specified readers and the default supported
diff --git a/monai/transforms/meta_utility/dictionary.py b/monai/transforms/meta_utility/dictionary.py
index 90a6666b952..bef228f4230 100644
--- a/monai/transforms/meta_utility/dictionary.py
+++ b/monai/transforms/meta_utility/dictionary.py
@@ -45,7 +45,7 @@ class FromMetaTensord(MapTransform, InvertibleTransform):
have the form `{"a": torch.Tensor, "a_meta_dict": dict, "a_transforms": list, "b": ...}`.
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY, TransformBackends.CUPY]
def __init__(
self, keys: KeysCollection, data_type: Union[Sequence[str], str] = "tensor", allow_missing_keys: bool = False
@@ -92,7 +92,7 @@ class ToMetaTensord(MapTransform, InvertibleTransform):
have the form `{"a": MetaTensor, "b": MetaTensor}`.
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY, TransformBackends.CUPY]
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index aab4aaa12fa..34709daa422 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -14,16 +14,17 @@
"""
import warnings
-from typing import Callable, Iterable, Optional, Sequence, Union
+from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
import numpy as np
import torch
+import torch.nn.functional as F
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.networks import one_hot
-from monai.networks.layers import GaussianFilter, apply_filter
+from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import Transform
from monai.transforms.utils import (
@@ -31,16 +32,10 @@
fill_holes,
get_largest_connected_component_mask,
get_unique_labels,
+ remove_small_objects,
)
from monai.transforms.utils_pytorch_numpy_unification import unravel_index
-from monai.utils import (
- TransformBackends,
- convert_data_type,
- convert_to_tensor,
- deprecated_arg,
- ensure_tuple,
- look_up_option,
-)
+from monai.utils import TransformBackends, convert_data_type, convert_to_tensor, ensure_tuple, look_up_option
from monai.utils.type_conversion import convert_to_dst_type
__all__ = [
@@ -48,10 +43,12 @@
"AsDiscrete",
"FillHoles",
"KeepLargestConnectedComponent",
+ "RemoveSmallObjects",
"LabelFilter",
"LabelToContour",
"MeanEnsemble",
"ProbNMS",
+ "SobelGradients",
"VoteEnsemble",
"Invert",
]
@@ -59,7 +56,7 @@
class Activations(Transform):
"""
- Add activation operations to the model output, typically `Sigmoid` or `Softmax`.
+ Activation operations, typically `Sigmoid` or `Softmax`.
Args:
sigmoid: whether to execute sigmoid function on model output before transform.
@@ -68,6 +65,8 @@ class Activations(Transform):
Defaults to ``False``.
other: callable function to execute other activation layers, for example:
`other = lambda x: torch.tanh(x)`. Defaults to ``None``.
+ kwargs: additional parameters to `torch.softmax` (used when ``softmax=True``).
+ Defaults to ``dim=0``, unrecognized parameters will be ignored.
Raises:
TypeError: When ``other`` is not an ``Optional[Callable]``.
@@ -76,9 +75,12 @@ class Activations(Transform):
backend = [TransformBackends.TORCH]
- def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional[Callable] = None) -> None:
+ def __init__(
+ self, sigmoid: bool = False, softmax: bool = False, other: Optional[Callable] = None, **kwargs
+ ) -> None:
self.sigmoid = sigmoid
self.softmax = softmax
+ self.kwargs = kwargs
if other is not None and not callable(other):
raise TypeError(f"other must be None or callable but is {type(other).__name__}.")
self.other = other
@@ -116,7 +118,7 @@ def __call__(
if sigmoid or self.sigmoid:
img_t = torch.sigmoid(img_t)
if softmax or self.softmax:
- img_t = torch.softmax(img_t, dim=0)
+ img_t = torch.softmax(img_t, dim=self.kwargs.get("dim", 0))
act_func = self.other if other is None else other
if act_func is not None:
@@ -127,12 +129,11 @@ def __call__(
class AsDiscrete(Transform):
"""
- Execute after model forward to transform model output to discrete values.
- It can complete below operations:
+ Convert the input tensor/array into discrete values, possible operations are:
- - execute `argmax` for input logits values.
- - threshold input value to 0.0 or 1.0.
- - convert input value to One-Hot format.
+ - `argmax`.
+ - threshold input value to binary values.
+ - convert input value to One-Hot format (set ``to_one_hot=N``, `N` is the number of classes).
- round the value to the closest integer.
Args:
@@ -144,6 +145,9 @@ class AsDiscrete(Transform):
Defaults to ``None``.
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].
+ kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`.
+ currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.
+ These default to ``0``, ``True``, ``torch.float`` respectively.
Example:
@@ -159,54 +163,26 @@ class AsDiscrete(Transform):
>>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))
# [[[0.0, 0.0]], [[1.0, 1.0]]]
- .. deprecated:: 0.6.0
- ``n_classes`` is deprecated, use ``to_onehot`` instead.
-
- .. deprecated:: 0.7.0
- ``num_classes`` is deprecated, use ``to_onehot`` instead.
- ``logit_thresh`` is deprecated, use ``threshold`` instead.
- ``threshold_values`` is deprecated, use ``threshold`` instead.
-
"""
backend = [TransformBackends.TORCH]
- @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.")
- @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.")
- @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.")
- @deprecated_arg(
- name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead."
- )
def __init__(
self,
argmax: bool = False,
to_onehot: Optional[int] = None,
threshold: Optional[float] = None,
rounding: Optional[str] = None,
- n_classes: Optional[int] = None, # deprecated
- num_classes: Optional[int] = None, # deprecated
- logit_thresh: float = 0.5, # deprecated
- threshold_values: Optional[bool] = False, # deprecated
+ **kwargs,
) -> None:
self.argmax = argmax
if isinstance(to_onehot, bool): # for backward compatibility
- warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
- to_onehot = num_classes if to_onehot else None
+ raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
self.to_onehot = to_onehot
-
- if isinstance(threshold, bool): # for backward compatibility
- warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
- threshold = logit_thresh if threshold else None
self.threshold = threshold
-
self.rounding = rounding
+ self.kwargs = kwargs
- @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.")
- @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.")
- @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.")
- @deprecated_arg(
- name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead."
- )
def __call__(
self,
img: NdarrayOrTensor,
@@ -214,10 +190,6 @@ def __call__(
to_onehot: Optional[int] = None,
threshold: Optional[float] = None,
rounding: Optional[str] = None,
- n_classes: Optional[int] = None, # deprecated
- num_classes: Optional[int] = None, # deprecated
- logit_thresh: Optional[float] = None, # deprecated
- threshold_values: Optional[bool] = None, # deprecated
) -> NdarrayOrTensor:
"""
Args:
@@ -232,31 +204,21 @@ def __call__(
rounding: if not None, round the data according to the specified option,
available options: ["torchrounding"].
- .. deprecated:: 0.6.0
- ``n_classes`` is deprecated, use ``to_onehot`` instead.
-
- .. deprecated:: 0.7.0
- ``num_classes`` is deprecated, use ``to_onehot`` instead.
- ``logit_thresh`` is deprecated, use ``threshold`` instead.
- ``threshold_values`` is deprecated, use ``threshold`` instead.
-
"""
if isinstance(to_onehot, bool):
- warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
- to_onehot = num_classes if to_onehot else None
- if isinstance(threshold, bool):
- warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
- threshold = logit_thresh if threshold else None
+ raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t, *_ = convert_data_type(img, torch.Tensor)
if argmax or self.argmax:
- img_t = torch.argmax(img_t, dim=0, keepdim=True)
+ img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))
to_onehot = self.to_onehot if to_onehot is None else to_onehot
if to_onehot is not None:
if not isinstance(to_onehot, int):
- raise AssertionError("the number of classes for One-Hot must be an integer.")
- img_t = one_hot(img_t, num_classes=to_onehot, dim=0)
+ raise ValueError("the number of classes for One-Hot must be an integer.")
+ img_t = one_hot(
+ img_t, num_classes=to_onehot, dim=self.kwargs.get("dim", 0), dtype=self.kwargs.get("dtype", torch.float)
+ )
threshold = self.threshold if threshold is None else threshold
if threshold is not None:
@@ -267,7 +229,7 @@ def __call__(
look_up_option(rounding, ["torchrounding"])
img_t = torch.round(img_t)
- img, *_ = convert_to_dst_type(img_t, img, dtype=torch.float)
+ img, *_ = convert_to_dst_type(img_t, img, dtype=self.kwargs.get("dtype", torch.float))
return img
@@ -315,7 +277,7 @@ class KeepLargestConnectedComponent(Transform):
"""
- backend = [TransformBackends.NUMPY]
+ backend = [TransformBackends.NUMPY, TransformBackends.CUPY]
def __init__(
self,
@@ -323,6 +285,7 @@ def __init__(
is_onehot: Optional[bool] = None,
independent: bool = True,
connectivity: Optional[int] = None,
+ num_components: int = 1,
) -> None:
"""
Args:
@@ -340,6 +303,7 @@ def __init__(
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
connectivity of ``input.ndim`` is used. for more details:
https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
+ num_components: The number of largest components to preserve.
"""
super().__init__()
@@ -347,6 +311,7 @@ def __init__(
self.is_onehot = is_onehot
self.independent = independent
self.connectivity = connectivity
+ self.num_components = num_components
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
@@ -366,7 +331,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
if self.independent:
for i in applied_labels:
foreground = img_[i] > 0 if is_onehot else img_[0] == i
- mask = get_largest_connected_component_mask(foreground, self.connectivity)
+ mask = get_largest_connected_component_mask(foreground, self.connectivity, self.num_components)
if is_onehot:
img_[i][foreground != mask] = 0
else:
@@ -375,18 +340,55 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
if not is_onehot: # not one-hot, union of labels
labels, *_ = convert_to_dst_type(applied_labels, dst=img_, wrap_sequence=True)
foreground = (img_[..., None] == labels).any(-1)[0]
- mask = get_largest_connected_component_mask(foreground, self.connectivity)
+ mask = get_largest_connected_component_mask(foreground, self.connectivity, self.num_components)
img_[0][foreground != mask] = 0
return convert_to_dst_type(img_, dst=img)[0]
# one-hot, union of labels
foreground = (img_[applied_labels, ...] == 1).any(0)
- mask = get_largest_connected_component_mask(foreground, self.connectivity)
+ mask = get_largest_connected_component_mask(foreground, self.connectivity, self.num_components)
for i in applied_labels:
img_[i][foreground != mask] = 0
return convert_to_dst_type(img_, dst=img)[0]
-class LabelFilter:
+class RemoveSmallObjects(Transform):
+ """
+ Use `skimage.morphology.remove_small_objects` to remove small objects from images.
+ See: https://scikit-image.org/docs/dev/api/skimage.morphology.html#remove-small-objects.
+
+ Data should be one-hotted.
+
+ Args:
+ min_size: objects smaller than this size are removed.
+ connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
+ Accepted values are ranging from 1 to input.ndim. If ``None``, a full
+ connectivity of ``input.ndim`` is used. For more details refer to linked scikit-image
+ documentation.
+ independent_channels: Whether or not to consider channels as independent. If true, then
+ conjoining islands from different labels will be removed if they are below the threshold.
+ If false, the overall size islands made from all non-background voxels will be used.
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(self, min_size: int = 64, connectivity: int = 1, independent_channels: bool = True) -> None:
+ self.min_size = min_size
+ self.connectivity = connectivity
+ self.independent_channels = independent_channels
+
+ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ img: shape must be (C, spatial_dim1[, spatial_dim2, ...]). Data
+ should be one-hotted.
+
+ Returns:
+ An array with shape (C, spatial_dim1[, spatial_dim2, ...]).
+ """
+ return remove_small_objects(img, self.min_size, self.connectivity, self.independent_channels)
+
+
+class LabelFilter(Transform):
"""
This transform filters out labels and can be used as a processing step to view only certain labels.
@@ -703,7 +705,7 @@ class ProbNMS(Transform):
prob_threshold: the probability threshold, the function will stop searching if
the highest probability is no larger than the threshold. The value should be
no less than 0.0. Defaults to 0.5.
- box_size: the box size (in pixel) to be removed around the the pixel with the maximum probability.
+ box_size: the box size (in pixel) to be removed around the pixel with the maximum probability.
It can be an integer that defines the size of a square or cube,
or a list containing different values for each dimensions. Defaults to 48.
@@ -718,7 +720,7 @@ class ProbNMS(Transform):
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.NUMPY]
def __init__(
self,
@@ -780,6 +782,8 @@ class Invert(Transform):
Utility transform to automatically invert the previously applied transforms.
"""
+ backend = [TransformBackends.TORCH]
+
def __init__(
self,
transform: Optional[InvertibleTransform] = None,
@@ -815,3 +819,113 @@ def __call__(self, data):
inverted = self.transform.inverse(data)
inverted = self.post_func(inverted.to(self.device))
return inverted
+
+
+class SobelGradients(Transform):
+ """Calculate Sobel gradients of a grayscale image with the shape of (CxH[xWxDx...]).
+
+ Args:
+ kernel_size: the size of the Sobel kernel. Defaults to 3.
+ spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient
+ along each of the provide axis. By default it calculate the gradient for all spatial axes.
+ normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.
+ normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.
+ padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
+ Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ See ``torch.nn.Conv1d()`` for more information.
+ dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
+
+ """
+
+ backend = [TransformBackends.TORCH]
+
+ def __init__(
+ self,
+ kernel_size: int = 3,
+ spatial_axes: Optional[Union[Sequence[int], int]] = None,
+ normalize_kernels: bool = True,
+ normalize_gradients: bool = False,
+ padding_mode: str = "reflect",
+ dtype: torch.dtype = torch.float32,
+ ) -> None:
+ super().__init__()
+ self.padding = padding_mode
+ self.spatial_axes = spatial_axes
+ self.normalize_kernels = normalize_kernels
+ self.normalize_gradients = normalize_gradients
+ self.kernel_diff, self.kernel_smooth = self._get_kernel(kernel_size, dtype)
+
+ def _get_kernel(self, size, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
+ if size < 3:
+ raise ValueError(f"Sobel kernel size should be at least three. {size} was given.")
+ if size % 2 == 0:
+ raise ValueError(f"Sobel kernel size should be an odd number. {size} was given.")
+
+ kernel_diff = torch.tensor([[[-1, 0, 1]]], dtype=dtype)
+ kernel_smooth = torch.tensor([[[1, 2, 1]]], dtype=dtype)
+ kernel_expansion = torch.tensor([[[1, 2, 1]]], dtype=dtype)
+
+ if self.normalize_kernels:
+ if not dtype.is_floating_point:
+ raise ValueError(
+ f"`dtype` for Sobel kernel should be floating point when `normalize_kernel==True`. {dtype} was given."
+ )
+ kernel_diff /= 2.0
+ kernel_smooth /= 4.0
+ kernel_expansion /= 4.0
+
+ # Expand the kernel to larger size than 3
+ expand = (size - 3) // 2
+ for _ in range(expand):
+ kernel_diff = F.conv1d(kernel_diff, kernel_expansion, padding=2)
+ kernel_smooth = F.conv1d(kernel_smooth, kernel_expansion, padding=2)
+
+ return kernel_diff.squeeze(), kernel_smooth.squeeze()
+
+ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
+ image_tensor = convert_to_tensor(image, track_meta=get_track_meta())
+
+ # Check/set spatial axes
+ n_spatial_dims = image_tensor.ndim - 1 # excluding the channel dimension
+ valid_spatial_axes = list(range(n_spatial_dims)) + list(range(-n_spatial_dims, 0))
+
+ # Check gradient axes to be valid
+ if self.spatial_axes is None:
+ spatial_axes = list(range(n_spatial_dims))
+ else:
+ invalid_axis = set(ensure_tuple(self.spatial_axes)) - set(valid_spatial_axes)
+ if invalid_axis:
+ raise ValueError(
+ f"The provide axes to calculate gradient is not valid: {invalid_axis}. "
+ f"The image has {n_spatial_dims} spatial dimensions so it should be: {valid_spatial_axes}."
+ )
+ spatial_axes = [ax % n_spatial_dims if ax < 0 else ax for ax in ensure_tuple(self.spatial_axes)]
+
+ # Add batch dimension for separable_filtering
+ image_tensor = image_tensor.unsqueeze(0)
+
+ # Get the Sobel kernels
+ kernel_diff = self.kernel_diff.to(image_tensor.device)
+ kernel_smooth = self.kernel_smooth.to(image_tensor.device)
+
+ # Calculate gradient
+ grad_list = []
+ for ax in spatial_axes:
+ kernels = [kernel_smooth] * n_spatial_dims
+ kernels[ax - 1] = kernel_diff
+ grad = separable_filtering(image_tensor, kernels, mode=self.padding)
+ if self.normalize_gradients:
+ grad_min = grad.min()
+ if grad_min != grad.max():
+ grad -= grad_min
+ grad_max = grad.max()
+ if grad_max > 0:
+ grad /= grad_max
+ grad_list.append(grad)
+
+ grads = torch.cat(grad_list, dim=1)
+
+ # Remove batch dimension and convert the gradient type to be the same as input image
+ grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0]
+
+ return grads
diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py
index 67e19384549..78d84a0bd1f 100644
--- a/monai/transforms/post/dictionary.py
+++ b/monai/transforms/post/dictionary.py
@@ -19,6 +19,7 @@
from copy import deepcopy
from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping, Optional, Sequence, Union
+import numpy as np
import torch
from monai import config
@@ -35,12 +36,14 @@
LabelToContour,
MeanEnsemble,
ProbNMS,
+ RemoveSmallObjects,
+ SobelGradients,
VoteEnsemble,
)
from monai.transforms.transform import MapTransform
from monai.transforms.utility.array import ToTensor
from monai.transforms.utils import allow_missing_keys_mode, convert_applied_interp_mode
-from monai.utils import PostFix, convert_to_tensor, deprecated_arg, ensure_tuple, ensure_tuple_rep
+from monai.utils import PostFix, convert_to_tensor, ensure_tuple, ensure_tuple_rep
__all__ = [
"ActivationsD",
@@ -61,6 +64,9 @@
"KeepLargestConnectedComponentD",
"KeepLargestConnectedComponentDict",
"KeepLargestConnectedComponentd",
+ "RemoveSmallObjectsD",
+ "RemoveSmallObjectsDict",
+ "RemoveSmallObjectsd",
"LabelFilterD",
"LabelFilterDict",
"LabelFilterd",
@@ -99,6 +105,7 @@ def __init__(
softmax: Union[Sequence[bool], bool] = False,
other: Optional[Union[Sequence[Callable], Callable]] = None,
allow_missing_keys: bool = False,
+ **kwargs,
) -> None:
"""
Args:
@@ -112,6 +119,8 @@ def __init__(
for example: `other = torch.tanh`. it also can be a sequence of Callable, each
element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
+ kwargs: additional parameters to `torch.softmax` (used when ``softmax=True``).
+ Defaults to ``dim=0``, unrecognized parameters will be ignored.
"""
super().__init__(keys, allow_missing_keys)
@@ -119,6 +128,7 @@ def __init__(
self.softmax = ensure_tuple_rep(softmax, len(self.keys))
self.other = ensure_tuple_rep(other, len(self.keys))
self.converter = Activations()
+ self.converter.kwargs = kwargs
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
@@ -134,12 +144,6 @@ class AsDiscreted(MapTransform):
backend = AsDiscrete.backend
- @deprecated_arg(name="n_classes", new_name="num_classes", since="0.6", msg_suffix="please use `to_onehot` instead.")
- @deprecated_arg("num_classes", since="0.7", msg_suffix="please use `to_onehot` instead.")
- @deprecated_arg("logit_thresh", since="0.7", msg_suffix="please use `threshold` instead.")
- @deprecated_arg(
- name="threshold_values", new_name="threshold", since="0.7", msg_suffix="please use `threshold` instead."
- )
def __init__(
self,
keys: KeysCollection,
@@ -148,10 +152,7 @@ def __init__(
threshold: Union[Sequence[Optional[float]], Optional[float]] = None,
rounding: Union[Sequence[Optional[str]], Optional[str]] = None,
allow_missing_keys: bool = False,
- n_classes: Optional[Union[Sequence[int], int]] = None, # deprecated
- num_classes: Optional[Union[Sequence[int], int]] = None, # deprecated
- logit_thresh: Union[Sequence[float], float] = 0.5, # deprecated
- threshold_values: Union[Sequence[bool], bool] = False, # deprecated
+ **kwargs,
) -> None:
"""
Args:
@@ -167,40 +168,28 @@ def __init__(
available options: ["torchrounding"]. it also can be a sequence of str or None,
each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
-
- .. deprecated:: 0.6.0
- ``n_classes`` is deprecated, use ``to_onehot`` instead.
-
- .. deprecated:: 0.7.0
- ``num_classes`` is deprecated, use ``to_onehot`` instead.
- ``logit_thresh`` is deprecated, use ``threshold`` instead.
- ``threshold_values`` is deprecated, use ``threshold`` instead.
+ kwargs: additional parameters to ``AsDiscrete``.
+ ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.
+ These default to ``0``, ``True``, ``torch.float`` respectively.
"""
super().__init__(keys, allow_missing_keys)
self.argmax = ensure_tuple_rep(argmax, len(self.keys))
- to_onehot_ = ensure_tuple_rep(to_onehot, len(self.keys))
- num_classes = ensure_tuple_rep(num_classes, len(self.keys))
self.to_onehot = []
- for flag, val in zip(to_onehot_, num_classes):
+ for flag in ensure_tuple_rep(to_onehot, len(self.keys)):
if isinstance(flag, bool):
- warnings.warn("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
- self.to_onehot.append(val if flag else None)
- else:
- self.to_onehot.append(flag)
+ raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
+ self.to_onehot.append(flag)
- threshold_ = ensure_tuple_rep(threshold, len(self.keys))
- logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys))
self.threshold = []
- for flag, val in zip(threshold_, logit_thresh):
+ for flag in ensure_tuple_rep(threshold, len(self.keys)):
if isinstance(flag, bool):
- warnings.warn("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
- self.threshold.append(val if flag else None)
- else:
- self.threshold.append(flag)
+ raise ValueError("`threshold_values=True/False` is deprecated, please use `threshold=value` instead.")
+ self.threshold.append(flag)
self.rounding = ensure_tuple_rep(rounding, len(self.keys))
self.converter = AsDiscrete()
+ self.converter.kwargs = kwargs
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
@@ -225,6 +214,7 @@ def __init__(
is_onehot: Optional[bool] = None,
independent: bool = True,
connectivity: Optional[int] = None,
+ num_components: int = 1,
allow_missing_keys: bool = False,
) -> None:
"""
@@ -245,12 +235,17 @@ def __init__(
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
connectivity of ``input.ndim`` is used. for more details:
https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
+ num_components: The number of largest components to preserve.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.converter = KeepLargestConnectedComponent(
- applied_labels=applied_labels, is_onehot=is_onehot, independent=independent, connectivity=connectivity
+ applied_labels=applied_labels,
+ is_onehot=is_onehot,
+ independent=independent,
+ connectivity=connectivity,
+ num_components=num_components,
)
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
@@ -260,6 +255,41 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d
+class RemoveSmallObjectsd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.transforms.RemoveSmallObjectsd`.
+
+ Args:
+ min_size: objects smaller than this size are removed.
+ connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
+ Accepted values are ranging from 1 to input.ndim. If ``None``, a full
+ connectivity of ``input.ndim`` is used. For more details refer to linked scikit-image
+ documentation.
+ independent_channels: Whether or not to consider channels as independent. If true, then
+ conjoining islands from different labels will be removed if they are below the threshold.
+ If false, the overall size islands made from all non-background voxels will be used.
+ """
+
+ backend = RemoveSmallObjects.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ min_size: int = 64,
+ connectivity: int = 1,
+ independent_channels: bool = True,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.converter = RemoveSmallObjects(min_size, connectivity, independent_channels)
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ d[key] = self.converter(d[key])
+ return d
+
+
class LabelFilterd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.LabelFilter`.
@@ -477,7 +507,7 @@ class ProbNMSd(MapTransform):
prob_threshold: the probability threshold, the function will stop searching if
the highest probability is no larger than the threshold. The value should be
no less than 0.0. Defaults to 0.5.
- box_size: the box size (in pixel) to be removed around the the pixel with the maximum probability.
+ box_size: the box size (in pixel) to be removed around the pixel with the maximum probability.
It can be an integer that defines the size of a square or cube,
or a list containing different values for each dimensions. Defaults to 48.
@@ -668,7 +698,12 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
inverted_data = self._totensor(inverted[orig_key])
else:
inverted_data = inverted[orig_key]
- d[key] = post_func(inverted_data.to(device))
+ if isinstance(inverted_data, np.ndarray) and torch.device(device).type != "cpu":
+ raise ValueError("Inverted data with type of 'numpy.ndarray' do not support GPU.")
+ elif isinstance(inverted_data, torch.Tensor):
+ d[key] = post_func(inverted_data.to(device))
+ else:
+ d[key] = post_func(inverted_data)
# save the invertd applied_operations if it's in the source dict
if InvertibleTransform.trace_key(orig_key) in d:
d[InvertibleTransform.trace_key(orig_key)] = inverted_data.applied_operations
@@ -758,11 +793,68 @@ def get_saver(self):
return self.saver
+class SobelGradientsd(MapTransform):
+ """Calculate Sobel horizontal and vertical gradients of a grayscale image.
+
+ Args:
+ keys: keys of the corresponding items to model output.
+ kernel_size: the size of the Sobel kernel. Defaults to 3.
+ spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient
+ along each of the provide axis. By default it calculate the gradient for all spatial axes.
+ normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.
+ normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.
+ padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
+ Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
+ See ``torch.nn.Conv1d()`` for more information.
+ dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
+ new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of
+ key intact. By default not prefix is set and the corresponding array to the key will be replaced.
+ allow_missing_keys: don't raise exception if key is missing.
+
+ """
+
+ backend = SobelGradients.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ kernel_size: int = 3,
+ spatial_axes: Optional[Union[Sequence[int], int]] = None,
+ normalize_kernels: bool = True,
+ normalize_gradients: bool = False,
+ padding_mode: str = "reflect",
+ dtype: torch.dtype = torch.float32,
+ new_key_prefix: Optional[str] = None,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.transform = SobelGradients(
+ kernel_size=kernel_size,
+ spatial_axes=spatial_axes,
+ normalize_kernels=normalize_kernels,
+ normalize_gradients=normalize_gradients,
+ padding_mode=padding_mode,
+ dtype=dtype,
+ )
+ self.new_key_prefix = new_key_prefix
+ self.kernel_diff = self.transform.kernel_diff
+ self.kernel_smooth = self.transform.kernel_smooth
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ new_key = key if self.new_key_prefix is None else self.new_key_prefix + key
+ d[new_key] = self.transform(d[key])
+
+ return d
+
+
ActivationsD = ActivationsDict = Activationsd
AsDiscreteD = AsDiscreteDict = AsDiscreted
FillHolesD = FillHolesDict = FillHolesd
InvertD = InvertDict = Invertd
KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd
+RemoveSmallObjectsD = RemoveSmallObjectsDict = RemoveSmallObjectsd
LabelFilterD = LabelFilterDict = LabelFilterd
LabelToContourD = LabelToContourDict = LabelToContourd
MeanEnsembleD = MeanEnsembleDict = MeanEnsembled
@@ -770,3 +862,4 @@ def get_saver(self):
SaveClassificationD = SaveClassificationDict = SaveClassificationd
VoteEnsembleD = VoteEnsembleDict = VoteEnsembled
EnsembleD = EnsembleDict = Ensembled
+SobelGradientsD = SobelGradientsDict = SobelGradientsd
diff --git a/monai/transforms/signal/__init__.py b/monai/transforms/signal/__init__.py
new file mode 100644
index 00000000000..1e97f894078
--- /dev/null
+++ b/monai/transforms/signal/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/transforms/signal/array.py b/monai/transforms/signal/array.py
new file mode 100644
index 00000000000..07f29a3039f
--- /dev/null
+++ b/monai/transforms/signal/array.py
@@ -0,0 +1,457 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A collection of transforms for signal operations
+https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design
+"""
+
+import warnings
+from typing import Any, Optional, Sequence
+
+import numpy as np
+import torch
+
+from monai.config.type_definitions import NdarrayOrTensor
+from monai.transforms.transform import RandomizableTransform, Transform
+from monai.transforms.utils import check_boundaries, paste, squarepulse
+from monai.utils import optional_import
+from monai.utils.enums import TransformBackends
+from monai.utils.type_conversion import convert_data_type, convert_to_tensor
+
+shift, has_shift = optional_import("scipy.ndimage.interpolation", name="shift")
+iirnotch, has_iirnotch = optional_import("scipy.signal", name="iirnotch")
+with warnings.catch_warnings():
+ warnings.simplefilter("ignore", UserWarning) # project-monai/monai#5204
+ filtfilt, has_filtfilt = optional_import("torchaudio.functional", name="filtfilt")
+central_frequency, has_central_frequency = optional_import("pywt", name="central_frequency")
+cwt, has_cwt = optional_import("pywt", name="cwt")
+
+__all__ = [
+ "SignalRandDrop",
+ "SignalRandScale",
+ "SignalRandShift",
+ "SignalRandAddSine",
+ "SignalRandAddSquarePulse",
+ "SignalRandAddGaussianNoise",
+ "SignalRandAddSinePartial",
+ "SignalRandAddSquarePulsePartial",
+ "SignalFillEmpty",
+ "SignalRemoveFrequency",
+ "SignalContinuousWavelet",
+]
+
+
+class SignalRandShift(RandomizableTransform):
+ """
+ Apply a random shift on a signal
+ """
+
+ backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
+
+ def __init__(
+ self, mode: Optional[str] = "wrap", filling: Optional[float] = 0.0, boundaries: Sequence[float] = (-1.0, 1.0)
+ ) -> None:
+ """
+ Args:
+ mode: define how the extension of the input array is done beyond its boundaries, see for more details :
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.shift.html.
+ filling: value to fill past edges of input if mode is ‘constant’. Default is 0.0. see for mode details :
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.shift.html.
+ boundaries: list defining lower and upper boundaries for the signal shift, default : ``[-1.0, 1.0]``
+ """
+ super().__init__()
+ check_boundaries(boundaries)
+ self.filling = filling
+ self.mode = mode
+ self.boundaries = boundaries
+
+ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ signal: input 1 dimension signal to be shifted
+ """
+ self.randomize(None)
+ self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])
+ length = signal.shape[1]
+ shift_idx = round(self.magnitude * length)
+ sig = convert_data_type(signal, np.ndarray)[0]
+ signal = convert_to_tensor(shift(input=sig, mode=self.mode, shift=shift_idx, cval=self.filling))
+ return signal
+
+
+class SignalRandScale(RandomizableTransform):
+ """
+ Apply a random rescaling on a signal
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(self, boundaries: Sequence[float] = (-1.0, 1.0)) -> None:
+ """
+ Args:
+ boundaries: list defining lower and upper boundaries for the signal scaling, default : ``[-1.0, 1.0]``
+ """
+ super().__init__()
+ check_boundaries(boundaries)
+ self.boundaries = boundaries
+
+ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ signal: input 1 dimension signal to be scaled
+ """
+ self.randomize(None)
+ self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])
+ signal = convert_to_tensor(self.magnitude * signal)
+
+ return signal
+
+
+class SignalRandDrop(RandomizableTransform):
+ """
+ Randomly drop a portion of a signal
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(self, boundaries: Sequence[float] = (0.0, 1.0)) -> None:
+ """
+ Args:
+ boundaries: list defining lower and upper boundaries for the signal drop,
+ lower and upper values need to be positive default : ``[0.0, 1.0]``
+ """
+ super().__init__()
+ check_boundaries(boundaries)
+ self.boundaries = boundaries
+
+ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ signal: input 1 dimension signal to be dropped
+ """
+ self.randomize(None)
+ self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])
+
+ length = signal.shape[len(signal.shape) - 1]
+ mask = torch.zeros(round(self.magnitude * length))
+ trange = torch.arange(length)
+ loc = trange[torch.randint(0, trange.size(0), (1,))]
+ signal = convert_to_tensor(paste(signal, mask, (loc,)))
+
+ return signal
+
+
+class SignalRandAddSine(RandomizableTransform):
+ """
+ Add a random sinusoidal signal to the input signal
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(self, boundaries: Sequence[float] = (0.1, 0.3), frequencies: Sequence[float] = (0.001, 0.02)) -> None:
+ """
+ Args:
+ boundaries: list defining lower and upper boundaries for the sinusoidal magnitude,
+ lower and upper values need to be positive ,default : ``[0.1, 0.3]``
+ frequencies: list defining lower and upper frequencies for sinusoidal
+ signal generation ,default : ``[0.001, 0.02]``
+ """
+ super().__init__()
+ check_boundaries(boundaries)
+ self.boundaries = boundaries
+ self.frequencies = frequencies
+
+ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ signal: input 1 dimension signal to which sinusoidal signal will be added
+ """
+ self.randomize(None)
+ self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])
+ self.freqs = self.R.uniform(low=self.frequencies[0], high=self.frequencies[1])
+
+ length = signal.shape[1]
+
+ time = np.arange(0, length, 1)
+ data = convert_to_tensor(self.freqs * time)
+ sine = self.magnitude * torch.sin(data)
+ signal = convert_to_tensor(signal) + sine
+
+ return signal
+
+
+class SignalRandAddSquarePulse(RandomizableTransform):
+ """
+ Add a random square pulse signal to the input signal
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(self, boundaries: Sequence[float] = (0.01, 0.2), frequencies: Sequence[float] = (0.001, 0.02)) -> None:
+ """
+ Args:
+ boundaries: list defining lower and upper boundaries for the square pulse magnitude,
+ lower and upper values need to be positive , default : ``[0.01, 0.2]``
+ frequencies: list defining lower and upper frequencies for the square pulse
+ signal generation , default : ``[0.001, 0.02]``
+ """
+ super().__init__()
+ check_boundaries(boundaries)
+ self.boundaries = boundaries
+ self.frequencies = frequencies
+
+ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ signal: input 1 dimension signal to which square pulse will be added
+ """
+ self.randomize(None)
+ self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])
+ self.freqs = self.R.uniform(low=self.frequencies[0], high=self.frequencies[1])
+
+ length = signal.shape[1]
+
+ time = np.arange(0, length, 1)
+ squaredpulse = self.magnitude * squarepulse(self.freqs * time)
+ signal = convert_to_tensor(signal) + squaredpulse
+
+ return signal
+
+
+class SignalRandAddSinePartial(RandomizableTransform):
+ """
+ Add a random partial sinusoidal signal to the input signal
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(
+ self,
+ boundaries: Sequence[float] = (0.1, 0.3),
+ frequencies: Sequence[float] = (0.001, 0.02),
+ fraction: Sequence[float] = (0.01, 0.2),
+ ) -> None:
+ """
+ Args:
+ boundaries: list defining lower and upper boundaries for the sinusoidal magnitude,
+ lower and upper values need to be positive , default : ``[0.1, 0.3]``
+ frequencies: list defining lower and upper frequencies for sinusoidal
+ signal generation , default : ``[0.001, 0.02]``
+ fraction: list defining lower and upper boundaries for partial signal generation
+ default : ``[0.01, 0.2]``
+ """
+ super().__init__()
+ check_boundaries(boundaries)
+ self.boundaries = boundaries
+ self.frequencies = frequencies
+ self.fraction = fraction
+
+ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ signal: input 1 dimension signal to which a partial sinusoidal signal
+ will be added
+ """
+ self.randomize(None)
+ self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])
+ self.fracs = self.R.uniform(low=self.fraction[0], high=self.fraction[1])
+ self.freqs = self.R.uniform(low=self.frequencies[0], high=self.frequencies[1])
+
+ length = signal.shape[len(signal.shape) - 1]
+
+ time_partial = np.arange(0, round(self.fracs * length), 1)
+ data = convert_to_tensor(self.freqs * time_partial)
+ sine_partial = self.magnitude * torch.sin(data)
+
+ loc = np.random.choice(range(length))
+ signal = paste(signal, sine_partial, (loc,))
+
+ return signal
+
+
+class SignalRandAddGaussianNoise(RandomizableTransform):
+ """
+ Add a random gaussian noise to the input signal
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(self, boundaries: Sequence[float] = (0.001, 0.02)) -> None:
+ """
+ Args:
+ boundaries: list defining lower and upper boundaries for the signal magnitude,
+ default : ``[0.001,0.02]``
+ """
+ super().__init__()
+ check_boundaries(boundaries)
+ self.boundaries = boundaries
+
+ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ signal: input 1 dimension signal to which gaussian noise will be added
+ """
+ self.randomize(None)
+ self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])
+ length = signal.shape[1]
+ gaussiannoise = self.magnitude * torch.randn(length)
+
+ signal = convert_to_tensor(signal) + gaussiannoise
+
+ return signal
+
+
+class SignalRandAddSquarePulsePartial(RandomizableTransform):
+ """
+ Add a random partial square pulse to a signal
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(
+ self,
+ boundaries: Sequence[float] = (0.01, 0.2),
+ frequencies: Sequence[float] = (0.001, 0.02),
+ fraction: Sequence[float] = (0.01, 0.2),
+ ) -> None:
+ """
+ Args:
+ boundaries: list defining lower and upper boundaries for the square pulse magnitude,
+ lower and upper values need to be positive , default : ``[0.01, 0.2]``
+ frequencies: list defining lower and upper frequencies for square pulse
+ signal generation example : ``[0.001, 0.02]``
+ fraction: list defining lower and upper boundaries for partial square pulse generation
+ default: ``[0.01, 0.2]``
+ """
+ super().__init__()
+ check_boundaries(boundaries)
+ self.boundaries = boundaries
+ self.frequencies = frequencies
+ self.fraction = fraction
+
+ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ signal: input 1 dimension signal to which a partial square pulse will be added
+ """
+ self.randomize(None)
+ self.magnitude = self.R.uniform(low=self.boundaries[0], high=self.boundaries[1])
+ self.fracs = self.R.uniform(low=self.fraction[0], high=self.fraction[1])
+ self.freqs = self.R.uniform(low=self.frequencies[0], high=self.frequencies[1])
+
+ length = signal.shape[len(signal.shape) - 1]
+
+ time_partial = np.arange(0, round(self.fracs * length), 1)
+ squaredpulse_partial = self.magnitude * squarepulse(self.freqs * time_partial)
+
+ loc = np.random.choice(range(length))
+ signal = paste(signal, squaredpulse_partial, (loc,))
+
+ return signal
+
+
+class SignalFillEmpty(Transform):
+ """
+ replace empty part of a signal (NaN)
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(self, replacement: float = 0.0) -> None:
+ """
+ Args:
+ replacement: value to replace nan items in signal
+ """
+ super().__init__()
+ self.replacement = replacement
+
+ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Args:
+ signal: signal to be filled
+ """
+ signal = torch.nan_to_num(convert_to_tensor(signal), nan=self.replacement)
+ return signal
+
+
+class SignalRemoveFrequency(Transform):
+ """
+ Remove a frequency from a signal
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(
+ self,
+ frequency: Optional[float] = None,
+ quality_factor: Optional[float] = None,
+ sampling_freq: Optional[float] = None,
+ ) -> None:
+ """
+ Args:
+ frequency: frequency to be removed from the signal
+ quality_factor: quality factor for notch filter
+ see : https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirnotch.html
+ sampling_freq: sampling frequency of the input signal
+ """
+ super().__init__()
+ self.frequency = frequency
+ self.quality_factor = quality_factor
+ self.sampling_freq = sampling_freq
+
+ def __call__(self, signal: np.ndarray) -> Any:
+ """
+ Args:
+ signal: signal to be frequency removed
+ """
+ b_notch, a_notch = convert_to_tensor(
+ iirnotch(self.frequency, self.quality_factor, self.sampling_freq), dtype=torch.float
+ )
+ y_notched = filtfilt(convert_to_tensor(signal), a_notch, b_notch)
+
+ return y_notched
+
+
+class SignalContinuousWavelet(Transform):
+ """
+ Generate continuous wavelet transform of a signal
+ """
+
+ backend = [TransformBackends.NUMPY]
+
+ def __init__(self, type: str = "mexh", length: float = 125.0, frequency: float = 500.0) -> None:
+ """
+ Args:
+ type: mother wavelet type.
+ Available options are: {``"mexh"``, ``"morl"``, ``"cmorB-C"``, , ``"gausP"``}
+ see : https://pywavelets.readthedocs.io/en/latest/ref/cwt.html
+ length: expected length, default ``125.0``
+ frequency: signal frequency, default ``500.0``
+ """
+ super().__init__()
+ self.frequency = frequency
+ self.length = length
+ self.type = type
+
+ def __call__(self, signal: np.ndarray) -> Any:
+ """
+ Args:
+ signal: signal for which to generate continuous wavelet transform
+ """
+ mother_wavelet = self.type
+ spread = np.arange(1, self.length + 1, 1)
+ scales = central_frequency(mother_wavelet) * self.frequency / spread
+
+ coeffs, _ = cwt(signal, scales, mother_wavelet, 1.0 / self.frequency)
+
+ coeffs = np.transpose(coeffs, [1, 0, 2])
+
+ return coeffs
diff --git a/monai/transforms/smooth_field/array.py b/monai/transforms/smooth_field/array.py
index 953c589288d..13507339e10 100644
--- a/monai/transforms/smooth_field/array.py
+++ b/monai/transforms/smooth_field/array.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""Transforms using a smooth spatial field generated by interpolating from smaller randomized fields."""
from typing import Any, Optional, Sequence, Union
@@ -52,6 +51,8 @@ class SmoothField(Randomizable):
device: Pytorch device to define field on
"""
+ backend = [TransformBackends.TORCH]
+
def __init__(
self,
rand_size: Sequence[int],
@@ -160,7 +161,7 @@ class RandSmoothFieldAdjustContrast(RandomizableTransform):
device: Pytorch device to define field on
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.TORCH]
def __init__(
self,
@@ -261,7 +262,7 @@ class RandSmoothFieldAdjustIntensity(RandomizableTransform):
device: Pytorch device to define field on
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.TORCH]
def __init__(
self,
@@ -404,7 +405,7 @@ def __init__(
device=device,
)
- grid_space = spatial_size if spatial_size is not None else self.sfield.field.shape[2:]
+ grid_space = tuple(spatial_size) if spatial_size is not None else self.sfield.field.shape[2:]
grid_ranges = [torch.linspace(-1, 1, d) for d in grid_space]
grid = meshgrid_ij(*grid_ranges)
@@ -446,6 +447,7 @@ def __call__(
dgrid = self.grid + field.to(self.grid_dtype)
dgrid = moveaxis(dgrid, 1, -1) # type: ignore
+ dgrid = dgrid[..., list(range(dgrid.shape[-1] - 1, -1, -1))] # invert order of coordinates
img_t = convert_to_tensor(img[None], torch.float32, device)
diff --git a/monai/transforms/smooth_field/dictionary.py b/monai/transforms/smooth_field/dictionary.py
index 48e00b9e4a0..08fb71edb46 100644
--- a/monai/transforms/smooth_field/dictionary.py
+++ b/monai/transforms/smooth_field/dictionary.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Any, Hashable, Mapping, Optional, Sequence, Union
import numpy as np
@@ -25,7 +24,6 @@
)
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, convert_to_tensor, ensure_tuple_rep
-from monai.utils.enums import TransformBackends
__all__ = [
"RandSmoothFieldAdjustContrastd",
@@ -60,7 +58,7 @@ class RandSmoothFieldAdjustContrastd(RandomizableTransform, MapTransform):
device: Pytorch device to define field on
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = RandSmoothFieldAdjustContrast.backend
def __init__(
self,
@@ -138,7 +136,7 @@ class RandSmoothFieldAdjustIntensityd(RandomizableTransform, MapTransform):
device: Pytorch device to define field on
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = RandSmoothFieldAdjustIntensity.backend
def __init__(
self,
@@ -219,7 +217,7 @@ class RandSmoothDeformd(RandomizableTransform, MapTransform):
device: Pytorch device to define field on
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = RandSmoothDeform.backend
def __init__(
self,
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index a8e76e098b0..dcddefce3a4 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -15,6 +15,7 @@
import warnings
from copy import deepcopy
from enum import Enum
+from itertools import zip_longest
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
@@ -24,7 +25,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
-from monai.data.utils import AFFINE_TOL, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
+from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
from monai.networks.utils import meshgrid_ij, normalize_transform
from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop
@@ -42,13 +43,17 @@
map_spatial_axes,
scale_affine,
)
-from monai.transforms.utils_pytorch_numpy_unification import allclose, linalg_inv, moveaxis
+from monai.transforms.utils_pytorch_numpy_unification import allclose, linalg_inv, moveaxis, where
from monai.utils import (
GridSampleMode,
GridSamplePadMode,
InterpolateMode,
+ NdimageMode,
NumpyPadMode,
+ SplineMode,
+ convert_to_cupy,
convert_to_dst_type,
+ convert_to_numpy,
convert_to_tensor,
ensure_tuple,
ensure_tuple_rep,
@@ -59,12 +64,15 @@
pytorch_after,
)
from monai.utils.deprecate_utils import deprecated_arg
-from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends
+from monai.utils.enums import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends, WSIPatchKeys
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string
nib, has_nib = optional_import("nibabel")
+cupy, _ = optional_import("cupy")
+cupy_ndi, _ = optional_import("cupyx.scipy.ndimage")
+np_ndi, _ = optional_import("scipy.ndimage")
__all__ = [
"SpatialResample",
@@ -108,31 +116,32 @@ class SpatialResample(InvertibleTransform):
by ``xform = linalg.solve(src_affine, dst_affine)``, and call ``monai.transforms.Affine`` with ``xform``.
"""
- backend = [TransformBackends.TORCH]
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY, TransformBackends.CUPY]
def __init__(
self,
- mode: str = GridSampleMode.BILINEAR,
+ mode: Union[str, int] = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: DtypeLike = np.float64,
):
"""
Args:
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- When `USE_COMPILED` is `True`, this argument uses
- ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
- See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
- See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If ``None``, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
"""
self.mode = mode
self.padding_mode = padding_mode
@@ -152,7 +161,7 @@ def _post_process(
"""
Small fn to simplify returning data. If `MetaTensor`, update affine. Elif
tracking metadata is desired, create `MetaTensor` with affine. Else, return
- image as `torch.Tensor`. Output type is always `torch.float32`.
+ image as `torch.Tensor`. Output type is always `float32`.
Also append the transform to the stack.
"""
@@ -185,9 +194,9 @@ def __call__(
src_affine: Optional[NdarrayOrTensor] = None,
dst_affine: Optional[torch.Tensor] = None,
spatial_size: Optional[Union[Sequence[int], torch.Tensor, int]] = None,
- mode: Optional[str] = None,
+ mode: Union[str, int, None] = None,
padding_mode: Optional[str] = None,
- align_corners: Optional[bool] = False,
+ align_corners: Optional[bool] = None,
dtype: DtypeLike = None,
) -> torch.Tensor:
"""
@@ -202,17 +211,21 @@ def __call__(
if `spatial_size` and `self.spatial_size` are not defined,
the transform will compute a spatial size automatically containing the previous field of view.
if `spatial_size` is ``-1`` are the transform will use the corresponding input img size.
- mode: {``"bilinear"``, ``"nearest"``}
- Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
+ Interpolation mode to calculate output values. Defaults to ``self.mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- When `USE_COMPILED` is `True`, this argument uses
- ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
- See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
- Padding mode for outside grid values. Defaults to ``"border"``.
+ Padding mode for outside grid values. Defaults to ``self.padding_mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ Defaults to ``None``, effectively using the value of `self.align_corners`.
dtype: data type for resampling computation. Defaults to ``self.dtype`` or
``np.float64`` (for best precision). If ``None``, use the data type of input data.
To be compatible with other modules, the output data type is always `float32`.
@@ -226,8 +239,8 @@ def __call__(
# get dtype as torch (e.g., torch.float64)
_dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
align_corners = self.align_corners if align_corners is None else align_corners
- mode = mode or self.mode
- padding_mode = padding_mode or self.padding_mode
+ mode = mode if mode is not None else self.mode
+ padding_mode = padding_mode if padding_mode is not None else self.padding_mode
original_spatial_shape = img.shape[1:]
src_affine_: torch.Tensor = img.affine if isinstance(img, MetaTensor) else torch.eye(4)
@@ -279,16 +292,13 @@ def __call__(
if additional_dims:
xform_shape = [-1] + in_spatial_size
img = img.reshape(xform_shape) # type: ignore
- if align_corners:
- _t_r = torch.eye(len(xform), dtype=xform.dtype, device=xform.device)
- for idx, d_dst in enumerate(spatial_size[:spatial_rank]):
- _t_r[idx, -1] = (max(d_dst, 2) - 1.0) / 2.0
- xform = xform @ _t_r
- if not USE_COMPILED:
- _t_l = normalize_transform(
- in_spatial_size, xform.device, xform.dtype, align_corners=True # type: ignore
- )[0]
- xform = _t_l @ xform
+ if isinstance(mode, int):
+ dst_xform_1 = normalize_transform(spatial_size, xform.device, xform.dtype, True, True)[0] # to (-1, 1)
+ if not align_corners:
+ norm = create_scale(spatial_rank, [(max(d, 2) - 1) / d for d in spatial_size], xform.device, "torch")
+ dst_xform_1 = norm.to(xform.dtype) @ dst_xform_1 # type: ignore # scaling (num_step - 1) / num_step
+ dst_xform_d = normalize_transform(spatial_size, xform.device, xform.dtype, align_corners, False)[0]
+ xform = xform @ torch.inverse(dst_xform_d) @ dst_xform_1
affine_xform = Affine(
affine=xform, spatial_size=spatial_size, normalized=True, image_only=True, dtype=_dtype
)
@@ -352,9 +362,9 @@ def __call__(
img_dst: torch.Tensor,
src_meta: Optional[Dict] = None,
dst_meta: Optional[Dict] = None,
- mode: Optional[str] = None,
+ mode: Union[str, int, None] = None,
padding_mode: Optional[str] = None,
- align_corners: Optional[bool] = False,
+ align_corners: Optional[bool] = None,
dtype: DtypeLike = None,
) -> torch.Tensor:
"""
@@ -369,17 +379,20 @@ def __call__(
specified, ``src_affine`` is assumed. If ``spatial_shape`` is not specified, spatial size is
automatically computed, containing the previous field of view. Defaults to ``None``.
See also: https://docs.monai.io/en/stable/transforms.html#spatialresample
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- When `USE_COMPILED` is `True`, this argument uses
- ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
- See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
- Defaults to ``False``.
+ Defaults to ``None``, effectively using the value of `self.align_corners`.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
dtype: data type for resampling computation. Defaults to ``self.dtype`` or
``np.float64`` (for best precision). If ``None``, use the data type of input data.
@@ -419,10 +432,14 @@ def __init__(
self,
pixdim: Union[Sequence[float], float, np.ndarray],
diagonal: bool = False,
- mode: str = GridSampleMode.BILINEAR,
+ mode: Union[str, int] = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: DtypeLike = np.float64,
+ scale_extent: bool = False,
+ recompute_affine: bool = False,
+ min_pixdim: Union[Sequence[float], float, np.ndarray, None] = None,
+ max_pixdim: Union[Sequence[float], float, np.ndarray, None] = None,
image_only: bool = False,
) -> None:
"""
@@ -430,7 +447,8 @@ def __init__(
pixdim: output voxel spacing. if providing a single number, will use it for the first dimension.
items of the pixdim sequence map to the spatial dimensions of input image, if length
of pixdim sequence is longer than image spatial dimensions, will ignore the longer part,
- if shorter, will pad with `1.0`.
+ if shorter, will pad with the last value. For example, for 3D image if pixdim is [1.0, 2.0] it
+ will be padded to [1.0, 2.0, 2.0]
if the components of the `pixdim` are non-positive values, the transform will use the
corresponding components of the original pixdim, which is computed from the `affine`
matrix of input image.
@@ -445,30 +463,51 @@ def __init__(
If False, this transform preserves the axes orientation, orthogonal rotation and
translation components from the original affine. This option will not flip/swap axes
of the original data.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- When `USE_COMPILED` is `True`, this argument uses
- ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
- See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
+ dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
+ scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,
+ default False. The option is ignored if output spatial size is specified when calling this transform.
+ See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`
+ should be `True` because `compute_shape_offset` already provides the corner alignment shift/scaling.
+ recompute_affine: whether to recompute affine based on the output shape. The affine computed
+ analytically does not reflect the potential quantization errors in terms of the output shape.
+ Set this flag to True to recompute the output affine based on the actual pixdim. Default to ``False``.
+ min_pixdim: minimal input spacing to be resampled. If provided, input image with a larger spacing than this
+ value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
+ value of `pixdim`. Default to `None`.
+ max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this
+ value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
+ value of `pixdim`. Default to `None`.
"""
self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64)
+ self.min_pixdim = np.array(ensure_tuple(min_pixdim), dtype=np.float64)
+ self.max_pixdim = np.array(ensure_tuple(max_pixdim), dtype=np.float64)
self.diagonal = diagonal
+ self.scale_extent = scale_extent
+ self.recompute_affine = recompute_affine
+
+ for mn, mx in zip(self.min_pixdim, self.max_pixdim):
+ if (not np.isnan(mn)) and (not np.isnan(mx)) and ((mx < mn) or (mn < 0)):
+ raise ValueError(f"min_pixdim {self.min_pixdim} must be positive, smaller than max {self.max_pixdim}.")
self.sp_resample = SpatialResample(
- mode=look_up_option(mode, GridSampleMode),
- padding_mode=look_up_option(padding_mode, GridSamplePadMode),
- align_corners=align_corners,
- dtype=dtype,
+ mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype
)
@deprecated_arg(name="affine", since="0.9", msg_suffix="Not needed, input should be `MetaTensor`.")
@@ -476,29 +515,38 @@ def __call__(
self,
data_array: torch.Tensor,
affine: Optional[NdarrayOrTensor] = None,
- mode: Optional[str] = None,
+ mode: Union[str, int, None] = None,
padding_mode: Optional[str] = None,
align_corners: Optional[bool] = None,
dtype: DtypeLike = None,
+ scale_extent: Optional[bool] = None,
output_spatial_shape: Optional[Union[Sequence[int], np.ndarray, int]] = None,
) -> torch.Tensor:
"""
Args:
data_array: in shape (num_channels, H[, W, ...]).
- mode: {``"bilinear"``, ``"nearest"``}
- Interpolation mode to calculate output values. Defaults to ``self.mode``.
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
+ Interpolation mode to calculate output values. Defaults to ``"self.mode"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- When `USE_COMPILED` is `True`, this argument uses
- ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
- See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
- Padding mode for outside grid values. Defaults to ``self.padding_mode``.
+ Padding mode for outside grid values. Defaults to ``"self.padding_mode"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ Defaults to ``None``, effectively using the value of `self.align_corners`.
dtype: data type for resampling computation. Defaults to ``self.dtype``.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
+ scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,
+ The option is ignored if output spatial size is specified when calling this transform.
+ See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`
+ should be `True` because `compute_shape_offset` already provides the corner alignment shift/scaling.
output_spatial_shape: specify the shape of the output data_array. This is typically useful for
the inverse of `Spacingd` where sometimes we could not compute the exact shape due to the quantization
error with the affine.
@@ -515,7 +563,6 @@ def __call__(
sr = len(original_spatial_shape)
if sr <= 0:
raise ValueError("data_array must have at least one spatial dimension.")
- input_affine: Optional[NdarrayOrTensor] = None
affine_: np.ndarray
if affine is not None:
warnings.warn("arg `affine` is deprecated, the affine of MetaTensor in data_array has higher priority.")
@@ -528,27 +575,45 @@ def __call__(
out_d = self.pixdim[:sr]
if out_d.size < sr:
- out_d = np.append(out_d, [1.0] * (sr - out_d.size))
+ out_d = np.append(out_d, [out_d[-1]] * (sr - out_d.size))
+
+ orig_d = affine_to_spacing(affine_, sr, out_d.dtype)
+ for idx, (_d, mn, mx) in enumerate(
+ zip_longest(orig_d, self.min_pixdim[:sr], self.max_pixdim[:sr], fillvalue=np.nan)
+ ):
+ target = out_d[idx]
+ mn = target if np.isnan(mn) else min(mn, target)
+ mx = target if np.isnan(mx) else max(mx, target)
+ if mn > mx:
+ raise ValueError(f"min_pixdim is larger than max_pixdim at dim {idx}: min {mn} max {mx} out {target}.")
+ out_d[idx] = _d if (mn - AFFINE_TOL) <= _d <= (mx + AFFINE_TOL) else target
+
+ if not align_corners and scale_extent:
+ warnings.warn("align_corners=False is not compatible with scale_extent=True.")
# compute output affine, shape and offset
new_affine = zoom_affine(affine_, out_d, diagonal=self.diagonal)
- output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine)
+ scale_extent = self.scale_extent if scale_extent is None else scale_extent
+ output_shape, offset = compute_shape_offset(data_array.shape[1:], affine_, new_affine, scale_extent)
new_affine[:sr, -1] = offset[:sr]
# convert to MetaTensor if necessary
data_array = convert_to_tensor(data_array, track_meta=get_track_meta())
- data_array.affine = torch.as_tensor(affine_) # type: ignore
+ if isinstance(data_array, MetaTensor):
+ data_array.affine = torch.as_tensor(affine_)
# we don't want to track the nested transform otherwise two will be appended
+ actual_shape = list(output_shape) if output_spatial_shape is None else output_spatial_shape
data_array = self.sp_resample(
data_array,
dst_affine=torch.as_tensor(new_affine),
- spatial_size=list(output_shape) if output_spatial_shape is None else output_spatial_shape,
+ spatial_size=actual_shape,
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
dtype=dtype,
)
-
+ if self.recompute_affine and isinstance(data_array, MetaTensor):
+ data_array.affine = scale_affine(affine_, original_spatial_shape, actual_shape)
return data_array
def inverse(self, data: torch.Tensor) -> torch.Tensor:
@@ -842,10 +907,13 @@ def __call__(
scale = self.spatial_size / max(img_size)
spatial_size_ = tuple(int(round(s * scale)) for s in img_size)
+ original_sp_size = img.shape[1:]
+ _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode)
+ _align_corners = self.align_corners if align_corners is None else align_corners
if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired
- return convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore
+ img = convert_to_tensor(img, track_meta=get_track_meta())
- original_sp_size = img.shape[1:]
+ return self._post_process(img, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim)
img_ = convert_to_tensor(img, dtype=torch.float, track_meta=False)
if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])):
@@ -862,25 +930,25 @@ def __call__(
img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False)
img = convert_to_tensor(img, track_meta=get_track_meta())
- _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode)
- _align_corners = self.align_corners if align_corners is None else align_corners
-
resized = torch.nn.functional.interpolate(
input=img_.unsqueeze(0), size=spatial_size_, mode=_mode, align_corners=_align_corners
)
out, *_ = convert_to_dst_type(resized.squeeze(0), img)
+ return self._post_process(out, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim)
+
+ def _post_process(self, img: torch.Tensor, orig_size, sp_size, mode, align_corners, ndim) -> torch.Tensor:
if get_track_meta():
- self.update_meta(out, original_sp_size, spatial_size_)
+ self.update_meta(img, orig_size, sp_size)
self.push_transform(
- out,
- orig_size=original_sp_size,
+ img,
+ orig_size=orig_size,
extra_info={
- "mode": _mode,
- "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE,
- "new_dim": len(original_sp_size) - input_ndim, # additional dims appended
+ "mode": mode,
+ "align_corners": align_corners if align_corners is not None else TraceKeys.NONE,
+ "new_dim": len(orig_size) - ndim, # additional dims appended
},
)
- return out
+ return img
def update_meta(self, img, spatial_size, new_spatial_size):
affine = convert_to_tensor(img.affine, track_meta=False)
@@ -921,9 +989,9 @@ class Rotate(InvertibleTransform):
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
align_corners: Defaults to False.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- dtype: data type for resampling computation. Defaults to ``np.float32``.
+ dtype: data type for resampling computation. Defaults to ``float32``.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
"""
backend = [TransformBackends.TORCH]
@@ -935,7 +1003,7 @@ def __init__(
mode: str = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.BORDER,
align_corners: bool = False,
- dtype: Union[DtypeLike, torch.dtype] = np.float32,
+ dtype: Union[DtypeLike, torch.dtype] = torch.float32,
) -> None:
self.angle = angle
self.keep_size = keep_size
@@ -967,7 +1035,7 @@ def __call__(
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
dtype: data type for resampling computation. Defaults to ``self.dtype``.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
Raises:
ValueError: When ``img`` spatially is not one of [2D, 3D].
@@ -1348,9 +1416,9 @@ class RandRotate(RandomizableTransform, InvertibleTransform):
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
align_corners: Defaults to False.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- dtype: data type for resampling computation. Defaults to ``np.float32``.
+ dtype: data type for resampling computation. Defaults to ``float32``.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
"""
backend = Rotate.backend
@@ -1420,7 +1488,7 @@ def __call__(
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
dtype: data type for resampling computation. Defaults to ``self.dtype``.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
randomize: whether to execute `randomize()` function first, default to True.
"""
if randomize:
@@ -1437,7 +1505,7 @@ def __call__(
)
out = rotator(img)
else:
- out = convert_to_tensor(img, track_meta=get_track_meta())
+ out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32)
if get_track_meta():
rot_info = self.pop_transform(out, check=False) if self._do_transform else {}
self.push_transform(out, extra_info=rot_info)
@@ -1648,7 +1716,7 @@ def __call__(
self.randomize(img=img)
if not self._do_transform:
- out = convert_to_tensor(img, track_meta=get_track_meta())
+ out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32)
else:
out = Zoom(
self._zoom,
@@ -1691,28 +1759,23 @@ class AffineGrid(Transform):
pixel/voxel relative to the center of the input image. Defaults to no translation.
scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D,
a tuple of 3 floats for 3D. Defaults to `1.0`.
- dtype: data type for the grid computation. Defaults to ``np.float32``.
+ dtype: data type for the grid computation. Defaults to ``float32``.
If ``None``, use the data type of input data (if `grid` is provided).
device: device on which the tensor will be allocated, if a new grid is generated.
affine: If applied, ignore the params (`rotate_params`, etc.) and use the
supplied matrix. Should be square with each side = num of image spatial
dimensions + 1.
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
backend = [TransformBackends.TORCH]
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
rotate_params: Optional[Union[Sequence[float], float]] = None,
shear_params: Optional[Union[Sequence[float], float]] = None,
translate_params: Optional[Union[Sequence[float], float]] = None,
scale_params: Optional[Union[Sequence[float], float]] = None,
- as_tensor_output: bool = True,
device: Optional[torch.device] = None,
dtype: DtypeLike = np.float32,
affine: Optional[NdarrayOrTensor] = None,
@@ -1780,14 +1843,12 @@ class RandAffineGrid(Randomizable, Transform):
backend = AffineGrid.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
rotate_range: RandRange = None,
shear_range: RandRange = None,
translate_range: RandRange = None,
scale_range: RandRange = None,
- as_tensor_output: bool = True,
device: Optional[torch.device] = None,
) -> None:
"""
@@ -1822,9 +1883,6 @@ def __init__(
- :py:meth:`monai.transforms.utils.create_translate`
- :py:meth:`monai.transforms.utils.create_scale`
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
self.rotate_range = ensure_tuple(rotate_range)
self.shear_range = ensure_tuple(shear_range)
@@ -1881,7 +1939,7 @@ def __call__(
device=self.device,
)
_grid: torch.Tensor
- _grid, self.affine = affine_grid(spatial_size, grid)
+ _grid, self.affine = affine_grid(spatial_size, grid) # type: ignore
return _grid
def get_transformation_matrix(self) -> Optional[torch.Tensor]:
@@ -1940,14 +1998,12 @@ def __call__(self, spatial_size: Sequence[int]) -> torch.Tensor:
class Resample(Transform):
- backend = [TransformBackends.TORCH]
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
- mode: str = GridSampleMode.BILINEAR,
+ mode: Union[str, int] = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.BORDER,
- as_tensor_output: bool = True,
norm_coords: bool = True,
device: Optional[torch.device] = None,
dtype: DtypeLike = np.float64,
@@ -1957,39 +2013,44 @@ def __init__(
supports spatially 2D or 3D (num_channels, H, W[, D]).
Args:
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `USE_COMPILED` is `True`, this argument uses
+ ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
+ See also: https://docs.monai.io/en/stable/networks.html#grid-pull (experimental).
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- When `USE_COMPILED` is `True`, this argument uses
- ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
- See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+ When `USE_COMPILED` is `True`, this argument uses an integer to represent the padding mode.
+ See also: https://docs.monai.io/en/stable/networks.html#grid-pull (experimental).
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
norm_coords: whether to normalize the coordinates from `[-(size-1)/2, (size-1)/2]` to
`[0, size - 1]` (for ``monai/csrc`` implementation) or
`[-1, 1]` (for torch ``grid_sample`` implementation) to be compatible with the underlying
resampling API.
device: device on which the tensor will be allocated.
- dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
+ dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If ``None``, use the data type of input data. To be compatible with other modules,
the output data type is always `float32`.
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
- self.mode: str = look_up_option(mode, GridSampleMode)
- self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
+ self.mode = mode
+ self.padding_mode = padding_mode
self.norm_coords = norm_coords
self.device = device
self.dtype = dtype
- def __call__( # type: ignore
+ def __call__(
self,
img: torch.Tensor,
- grid: torch.Tensor,
- mode: Optional[str] = None,
+ grid: Optional[torch.Tensor] = None,
+ mode: Union[str, int, None] = None,
padding_mode: Optional[str] = None,
dtype: DtypeLike = None,
) -> torch.Tensor:
@@ -2000,47 +2061,82 @@ def __call__( # type: ignore
if ``norm_coords`` is True, the grid values must be in `[-(size-1)/2, (size-1)/2]`.
if ``USE_COMPILED=True`` and ``norm_coords=False``, grid values must be in `[0, size-1]`.
if ``USE_COMPILED=False`` and ``norm_coords=False``, grid values must be in `[-1, 1]`.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``self.mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
When `USE_COMPILED` is `True`, this argument uses
``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
- See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+ See also: https://docs.monai.io/en/stable/networks.html#grid-pull (experimental).
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``self.padding_mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `USE_COMPILED` is `True`, this argument uses an integer to represent the padding mode.
+ See also: https://docs.monai.io/en/stable/networks.html#grid-pull (experimental).
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
dtype: data type for resampling computation. Defaults to ``self.dtype``.
To be compatible with other modules, the output data type is always `float32`.
See also:
:py:const:`monai.config.USE_COMPILED`
"""
+ img = convert_to_tensor(img, track_meta=get_track_meta())
+ if grid is None:
+ return img
_device = img.device if isinstance(img, torch.Tensor) else self.device
_dtype = dtype or self.dtype or img.dtype
- img = convert_to_tensor(img, track_meta=get_track_meta())
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype, device=_device)
- grid_t, *_ = convert_to_dst_type(grid, img_t)
- if grid_t is grid: # copy if needed (convert_data_type converts to contiguous)
- grid_t = grid_t.clone(memory_format=torch.contiguous_format)
+ grid_t, *_ = convert_to_dst_type(grid, img_t, dtype=grid.dtype, wrap_sequence=True)
+ grid_t = grid_t.clone(memory_format=torch.contiguous_format)
+
+ if self.norm_coords:
+ grid_t[-1] = where(grid_t[-1] != 0, grid_t[-1], 1.0) # type: ignore
sr = min(len(img_t.shape[1:]), 3)
- if USE_COMPILED:
+ _interp_mode = self.mode if mode is None else mode
+ _padding_mode = self.padding_mode if padding_mode is None else padding_mode
+ if look_up_option(str(_interp_mode), SplineMode, default=None) is not None:
+ self._backend = TransformBackends.NUMPY
+ else:
+ self._backend = TransformBackends.TORCH
+
+ if USE_COMPILED or self._backend == TransformBackends.NUMPY:
if self.norm_coords:
for i, dim in enumerate(img_t.shape[1 : 1 + sr]):
grid_t[i] = (max(dim, 2) / 2.0 - 0.5 + grid_t[i]) / grid_t[-1:]
- grid_t = moveaxis(grid_t[:sr], 0, -1) # type: ignore
- _padding_mode = self.padding_mode if padding_mode is None else padding_mode
- bound = 1 if _padding_mode == "reflection" else _padding_mode
- _interp_mode = self.mode if mode is None else mode
- if _interp_mode == "bicubic":
- interp = 3
- elif _interp_mode == "bilinear":
- interp = 1
- else:
- interp = _interp_mode # type: ignore
- out = grid_pull(
- img_t.unsqueeze(0), grid_t.unsqueeze(0), bound=bound, extrapolate=True, interpolation=interp
- )[0]
+ grid_t = grid_t[:sr]
+ if USE_COMPILED and self._backend == TransformBackends.TORCH: # compiled is using torch backend param name
+ grid_t = moveaxis(grid_t, 0, -1) # type: ignore
+ bound = 1 if _padding_mode == "reflection" else _padding_mode
+ if _interp_mode == "bicubic":
+ interp = 3
+ elif _interp_mode == "bilinear":
+ interp = 1
+ else:
+ interp = GridSampleMode(_interp_mode) # type: ignore
+ out = grid_pull(
+ img_t.unsqueeze(0),
+ grid_t.unsqueeze(0).to(img_t),
+ bound=bound,
+ extrapolate=True,
+ interpolation=interp,
+ )[0]
+ elif self._backend == TransformBackends.NUMPY:
+ is_cuda = img_t.is_cuda
+ img_np = (convert_to_cupy if is_cuda else convert_to_numpy)(img_t, wrap_sequence=True)
+ grid_np, *_ = convert_to_dst_type(grid_t, img_np, wrap_sequence=True)
+ _map_coord = (cupy_ndi if is_cuda else np_ndi).map_coordinates
+ out = (cupy if is_cuda else np).stack(
+ [
+ _map_coord(c, grid_np, order=int(_interp_mode), mode=look_up_option(_padding_mode, NdimageMode))
+ for c in img_np
+ ]
+ )
+ out = convert_to_dst_type(out, img_t)[0]
else:
if self.norm_coords:
for i, dim in enumerate(img_t.shape[1 : 1 + sr]):
@@ -2049,9 +2145,9 @@ def __call__( # type: ignore
grid_t = moveaxis(grid_t[index_ordering], 0, -1) # type: ignore
out = torch.nn.functional.grid_sample(
img_t.unsqueeze(0),
- grid_t.unsqueeze(0),
- mode=self.mode if mode is None else GridSampleMode(mode),
- padding_mode=self.padding_mode if padding_mode is None else GridSamplePadMode(padding_mode),
+ grid_t.unsqueeze(0).to(img_t),
+ mode=GridSampleMode(_interp_mode),
+ padding_mode=GridSamplePadMode(_padding_mode),
align_corners=True,
)[0]
out_val, *_ = convert_to_dst_type(out, dst=img, dtype=np.float32)
@@ -2067,7 +2163,6 @@ class Affine(InvertibleTransform):
backend = list(set(AffineGrid.backend) & set(Resample.backend))
- @deprecated_arg(name="as_tensor_output", since="0.6")
@deprecated_arg(name="norm_coords", since="0.8")
def __init__(
self,
@@ -2077,11 +2172,10 @@ def __init__(
scale_params: Optional[Union[Sequence[float], float]] = None,
affine: Optional[NdarrayOrTensor] = None,
spatial_size: Optional[Union[Sequence[int], int]] = None,
- mode: str = GridSampleMode.BILINEAR,
+ mode: Union[str, int] = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.REFLECTION,
normalized: bool = False,
norm_coords: bool = True,
- as_tensor_output: bool = True,
device: Optional[torch.device] = None,
dtype: DtypeLike = np.float32,
image_only: bool = False,
@@ -2115,28 +2209,29 @@ def __init__(
if some components of the `spatial_size` are non-positive values, the transform will use the
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- When `USE_COMPILED` is `True`, this argument uses
- ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
- See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"reflection"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
normalized: indicating whether the provided `affine` is defined to include a normalization
transform converting the coordinates from `[-(size-1)/2, (size-1)/2]` (defined in ``create_grid``) to
`[0, size - 1]` or `[-1, 1]` in order to be compatible with the underlying resampling API.
If `normalized=False`, additional coordinate normalization will be applied before resampling.
See also: :py:func:`monai.networks.utils.normalize_transform`.
device: device on which the tensor will be allocated.
- dtype: data type for resampling computation. Defaults to ``np.float32``.
+ dtype: data type for resampling computation. Defaults to ``float32``.
If ``None``, use the data type of input data. To be compatible with other modules,
the output data type is always `float32`.
image_only: if True return only the image volume, otherwise return (image, affine).
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
.. deprecated:: 0.8.1
``norm_coords`` is deprecated, please use ``normalized`` instead
(the new flag is a negation, i.e., ``norm_coords == not normalized``).
@@ -2155,14 +2250,14 @@ def __init__(
self.norm_coord = not normalized
self.resampler = Resample(norm_coords=self.norm_coord, device=device, dtype=dtype)
self.spatial_size = spatial_size
- self.mode: str = look_up_option(mode, GridSampleMode)
- self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
+ self.mode = mode
+ self.padding_mode: str = padding_mode
def __call__(
self,
img: torch.Tensor,
spatial_size: Optional[Union[Sequence[int], int]] = None,
- mode: Optional[str] = None,
+ mode: Union[str, int, None] = None,
padding_mode: Optional[str] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, NdarrayOrTensor]]:
"""
@@ -2173,21 +2268,24 @@ def __call__(
the transform will use the spatial size of `img`.
if `img` has two spatial dimensions, `spatial_size` should have 2 elements [h, w].
if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d].
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``self.mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- When `USE_COMPILED` is `True`, this argument uses
- ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
- See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``self.padding_mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
img_size = img.shape[1:]
sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img_size)
- _mode = mode or self.mode
- _padding_mode = padding_mode or self.padding_mode
+ _mode = mode if mode is not None else self.mode
+ _padding_mode = padding_mode if padding_mode is not None else self.padding_mode
grid, affine = self.affine_grid(spatial_size=sp_size)
out = self.resampler(img, grid=grid, mode=_mode, padding_mode=_padding_mode)
if not isinstance(out, MetaTensor):
@@ -2231,7 +2329,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
out = MetaTensor(out)
out.meta = data.meta # type: ignore
self.update_meta(out, inv_affine, data.shape[1:], orig_size)
- return out # type: ignore
+ return out
class RandAffine(RandomizableTransform, InvertibleTransform):
@@ -2243,7 +2341,6 @@ class RandAffine(RandomizableTransform, InvertibleTransform):
backend = Affine.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
prob: float = 0.1,
@@ -2252,10 +2349,9 @@ def __init__(
translate_range: RandRange = None,
scale_range: RandRange = None,
spatial_size: Optional[Union[Sequence[int], int]] = None,
- mode: str = GridSampleMode.BILINEAR,
+ mode: Union[str, int] = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.REFLECTION,
cache_grid: bool = False,
- as_tensor_output: bool = True,
device: Optional[torch.device] = None,
) -> None:
"""
@@ -2290,12 +2386,18 @@ def __init__(
if some components of the `spatial_size` are non-positive values, the transform will use the
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
- mode: {``"bilinear"``, ``"nearest"``}
- Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
+ Interpolation mode to calculate output values. Defaults to ``bilinear``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
- Padding mode for outside grid values. Defaults to ``"reflection"``.
+ Padding mode for outside grid values. Defaults to ``reflection``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
cache_grid: whether to cache the identity sampling grid.
If the spatial size is not dynamically defined by input image, enabling this option could
accelerate the transform.
@@ -2305,9 +2407,6 @@ def __init__(
- :py:class:`RandAffineGrid` for the random affine parameters configurations.
- :py:class:`Affine` for the affine transformation parameters configurations.
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
RandomizableTransform.__init__(self, prob)
@@ -2323,8 +2422,8 @@ def __init__(
self.spatial_size = spatial_size
self.cache_grid = cache_grid
self._cached_grid = self._init_identity_cache()
- self.mode: str = GridSampleMode(mode)
- self.padding_mode: str = GridSamplePadMode(padding_mode)
+ self.mode = mode
+ self.padding_mode: str = padding_mode
def _init_identity_cache(self):
"""
@@ -2383,7 +2482,7 @@ def __call__(
self,
img: torch.Tensor,
spatial_size: Optional[Union[Sequence[int], int]] = None,
- mode: Optional[str] = None,
+ mode: Union[str, int, None] = None,
padding_mode: Optional[str] = None,
randomize: bool = True,
grid=None,
@@ -2396,12 +2495,18 @@ def __call__(
the transform will use the spatial size of `img`.
if `img` has two spatial dimensions, `spatial_size` should have 2 elements [h, w].
if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d].
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``self.mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``self.padding_mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
randomize: whether to execute `randomize()` function first, default to True.
grid: precomputed grid to be used (mainly to accelerate `RandAffined`).
@@ -2412,8 +2517,8 @@ def __call__(
# except convert to float and device
sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img.shape[1:])
do_resampling = self._do_transform or (sp_size != ensure_tuple(img.shape[1:]))
- _mode = mode or self.mode
- _padding_mode = padding_mode or self.padding_mode
+ _mode = mode if mode is not None else self.mode
+ _padding_mode = padding_mode if padding_mode is not None else self.padding_mode
img = convert_to_tensor(img, track_meta=get_track_meta())
if not do_resampling:
out: torch.Tensor = convert_data_type(img, dtype=torch.float32, device=self.resampler.device)[0]
@@ -2465,7 +2570,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
out = MetaTensor(out)
out.meta = data.meta # type: ignore
self.update_meta(out, inv_affine, data.shape[1:], orig_size)
- return out # type: ignore
+ return out
class Rand2DElastic(RandomizableTransform):
@@ -2477,7 +2582,6 @@ class Rand2DElastic(RandomizableTransform):
backend = Resample.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
spacing: Union[Tuple[float, float], float],
@@ -2488,9 +2592,8 @@ def __init__(
translate_range: RandRange = None,
scale_range: RandRange = None,
spatial_size: Optional[Union[Tuple[int, int], int]] = None,
- mode: str = GridSampleMode.BILINEAR,
+ mode: Union[str, int] = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.REFLECTION,
- as_tensor_output: bool = False,
device: Optional[torch.device] = None,
) -> None:
"""
@@ -2526,21 +2629,24 @@ def __init__(
if some components of the `spatial_size` are non-positive values, the transform will use the
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"reflection"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
device: device on which the tensor will be allocated.
See also:
- :py:class:`RandAffineGrid` for the random affine parameters configurations.
- :py:class:`Affine` for the affine transformation parameters configurations.
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
RandomizableTransform.__init__(self, prob)
self.deform_grid = RandDeformGrid(spacing=spacing, magnitude_range=magnitude_range, device=device)
@@ -2555,8 +2661,8 @@ def __init__(
self.device = device
self.spatial_size = spatial_size
- self.mode: str = look_up_option(mode, GridSampleMode)
- self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
+ self.mode = mode
+ self.padding_mode: str = padding_mode
def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
@@ -2566,6 +2672,12 @@ def set_random_state(
super().set_random_state(seed, state)
return self
+ def set_device(self, device):
+ self.deform_grid.device = device
+ self.rand_affine_grid.device = device
+ self.resampler.device = device
+ self.device = device
+
def randomize(self, spatial_size: Sequence[int]) -> None:
super().randomize(None)
if not self._do_transform:
@@ -2577,7 +2689,7 @@ def __call__(
self,
img: torch.Tensor,
spatial_size: Optional[Union[Tuple[int, int], int]] = None,
- mode: Optional[str] = None,
+ mode: Union[str, int, None] = None,
padding_mode: Optional[str] = None,
randomize: bool = True,
) -> torch.Tensor:
@@ -2587,12 +2699,18 @@ def __call__(
spatial_size: specifying output image spatial size [h, w].
if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
the transform will use the spatial size of `img`.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``self.mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``self.padding_mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
randomize: whether to execute `randomize()` function first, default to True.
"""
sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img.shape[1:])
@@ -2614,7 +2732,10 @@ def __call__(
_device = img.device if isinstance(img, torch.Tensor) else self.device
grid = create_grid(spatial_size=sp_size, device=_device, backend="torch")
out: torch.Tensor = self.resampler(
- img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
+ img,
+ grid,
+ mode=mode if mode is not None else self.mode,
+ padding_mode=padding_mode if padding_mode is not None else self.padding_mode,
)
return out
@@ -2628,7 +2749,6 @@ class Rand3DElastic(RandomizableTransform):
backend = Resample.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
sigma_range: Tuple[float, float],
@@ -2639,9 +2759,8 @@ def __init__(
translate_range: RandRange = None,
scale_range: RandRange = None,
spatial_size: Optional[Union[Tuple[int, int, int], int]] = None,
- mode: str = GridSampleMode.BILINEAR,
+ mode: Union[str, int] = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.REFLECTION,
- as_tensor_output: bool = False,
device: Optional[torch.device] = None,
) -> None:
"""
@@ -2680,21 +2799,24 @@ def __init__(
if some components of the `spatial_size` are non-positive values, the transform will use the
corresponding components of img size. For example, `spatial_size=(32, 32, -1)` will be adapted
to `(32, 32, 64)` if the third spatial dimension size of img is `64`.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"reflection"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
device: device on which the tensor will be allocated.
See also:
- :py:class:`RandAffineGrid` for the random affine parameters configurations.
- :py:class:`Affine` for the affine transformation parameters configurations.
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
RandomizableTransform.__init__(self, prob)
self.rand_affine_grid = RandAffineGrid(
@@ -2709,8 +2831,8 @@ def __init__(
self.sigma_range = sigma_range
self.magnitude_range = magnitude_range
self.spatial_size = spatial_size
- self.mode: str = look_up_option(mode, GridSampleMode)
- self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode)
+ self.mode = mode
+ self.padding_mode: str = padding_mode
self.device = device
self.rand_offset: np.ndarray
@@ -2724,6 +2846,11 @@ def set_random_state(
super().set_random_state(seed, state)
return self
+ def set_device(self, device):
+ self.rand_affine_grid.device = device
+ self.resampler.device = device
+ self.device = device
+
def randomize(self, grid_size: Sequence[int]) -> None:
super().randomize(None)
if not self._do_transform:
@@ -2737,7 +2864,7 @@ def __call__(
self,
img: torch.Tensor,
spatial_size: Optional[Union[Tuple[int, int, int], int]] = None,
- mode: Optional[str] = None,
+ mode: Union[str, int, None] = None,
padding_mode: Optional[str] = None,
randomize: bool = True,
) -> torch.Tensor:
@@ -2747,12 +2874,18 @@ def __call__(
spatial_size: specifying spatial 3D output image spatial size [h, w, d].
if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
the transform will use the spatial size of `img`.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``self.mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``self.padding_mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
randomize: whether to execute `randomize()` function first, default to True.
"""
sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img.shape[1:])
@@ -2769,7 +2902,10 @@ def __call__(
grid[:3] += gaussian(offset)[0] * self.magnitude
grid = self.rand_affine_grid(grid=grid)
out: torch.Tensor = self.resampler(
- img, grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode
+ img,
+ grid, # type: ignore
+ mode=mode if mode is not None else self.mode,
+ padding_mode=padding_mode if padding_mode is not None else self.padding_mode,
)
return out
@@ -2782,7 +2918,7 @@ def __init__(
self,
num_cells: Union[Tuple[int], int],
distort_steps: Sequence[Sequence[float]],
- mode: str = GridSampleMode.BILINEAR,
+ mode: Union[str, int] = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.BORDER,
device: Optional[torch.device] = None,
) -> None:
@@ -2795,12 +2931,18 @@ def __init__(
distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the
corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.
Each value in the tuple represents the distort step of the related cell.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
device: device on which the tensor will be allocated.
"""
@@ -2822,12 +2964,18 @@ def __call__(
distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the
corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.
Each value in the tuple represents the distort step of the related cell.
- mode: {``"bilinear"``, ``"nearest"``}
- Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
+ Interpolation mode to calculate output values. Defaults to ``self.mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
- Padding mode for outside grid values. Defaults to ``"border"``.
+ Padding mode for outside grid values. Defaults to ``self.padding_mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
"""
distort_steps = self.distort_steps if distort_steps is None else distort_steps
@@ -2857,7 +3005,7 @@ def __call__(
coords = meshgrid_ij(*all_ranges)
grid = torch.stack([*coords, torch.ones_like(coords[0])])
- return self.resampler(img, grid=grid, mode=mode, padding_mode=padding_mode) # type: ignore
+ return self.resampler(img, grid=grid, mode=mode, padding_mode=padding_mode)
class RandGridDistortion(RandomizableTransform):
@@ -2869,7 +3017,7 @@ def __init__(
num_cells: Union[Tuple[int], int] = 5,
prob: float = 0.1,
distort_limit: Union[Tuple[float, float], float] = (-0.03, 0.03),
- mode: str = GridSampleMode.BILINEAR,
+ mode: Union[str, int] = GridSampleMode.BILINEAR,
padding_mode: str = GridSamplePadMode.BORDER,
device: Optional[torch.device] = None,
) -> None:
@@ -2883,12 +3031,18 @@ def __init__(
distort_limit: range to randomly distort.
If single number, distort_limit is picked from (-distort_limit, distort_limit).
Defaults to (-0.03, 0.03).
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
device: device on which the tensor will be allocated.
"""
@@ -2918,12 +3072,18 @@ def __call__(
"""
Args:
img: shape must be (num_channels, H, W[, D]).
- mode: {``"bilinear"``, ``"nearest"``}
- Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
+ Interpolation mode to calculate output values. Defaults to ``self.mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
- Padding mode for outside grid values. Defaults to ``"border"``.
+ Padding mode for outside grid values. Defaults to ``self.padding_mode``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
randomize: whether to shuffle the random factors using `randomize()`, default to True.
"""
if randomize:
@@ -3039,6 +3199,9 @@ class GridPatch(Transform):
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
+ Returns:
+ MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata
+
"""
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
@@ -3122,21 +3285,24 @@ def __call__(self, array: NdarrayOrTensor):
elif self.threshold:
patched_image, locations = self.filter_threshold(patched_image, locations)
- # Convert to original data type
- output = list(
- zip(
- convert_to_dst_type(src=patched_image, dst=array)[0],
- convert_to_dst_type(src=locations, dst=array, dtype=int)[0],
- )
- )
-
# Pad the patch list to have the requested number of patches
- if self.num_patches and len(output) < self.num_patches:
- patch = convert_to_dst_type(
- src=np.full((array.shape[0], *self.patch_size), self.pad_kwargs.get("constant_values", 0)), dst=array
- )[0]
- start_location = convert_to_dst_type(src=np.zeros(len(self.patch_size)), dst=array)[0]
- output += [(patch, start_location)] * (self.num_patches - len(output))
+ if self.num_patches:
+ padding = self.num_patches - len(patched_image)
+ if padding > 0:
+ patched_image = np.pad(
+ patched_image,
+ [[0, padding], [0, 0]] + [[0, 0]] * len(self.patch_size),
+ constant_values=self.pad_kwargs.get("constant_values", 0),
+ )
+ locations = np.pad(locations, [[0, padding], [0, 0]], constant_values=0)
+
+ # Convert to MetaTensor
+ metadata = array.meta if isinstance(array, MetaTensor) else MetaTensor.get_default_meta()
+ metadata[WSIPatchKeys.LOCATION] = locations.T
+ metadata[WSIPatchKeys.COUNT] = len(locations)
+ metadata["spatial_shape"] = np.tile(np.array(self.patch_size), (len(locations), 1)).T
+ output = MetaTensor(x=patched_image, meta=metadata)
+ output.is_batch = True
return output
@@ -3162,6 +3328,9 @@ class RandGridPatch(GridPatch, RandomizableTransform):
pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied. Defaults to ``"constant"``.
pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
+ Returns:
+ MetaTensor: A MetaTensor consisting of a batch of all the patches with associated metadata
+
"""
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py
index 493369d2584..706e8d7f8b4 100644
--- a/monai/transforms/spatial/dictionary.py
+++ b/monai/transforms/spatial/dictionary.py
@@ -15,7 +15,6 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""
-from copy import deepcopy
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
@@ -58,12 +57,10 @@
GridSamplePadMode,
InterpolateMode,
NumpyPadMode,
- WSIPatchKeys,
convert_to_tensor,
ensure_tuple,
ensure_tuple_rep,
fall_back_tuple,
- first,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import PytorchPadMode, TraceKeys
@@ -178,20 +175,26 @@ def __init__(
"""
Args:
keys: keys of the corresponding items to be transformed.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
- See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
- See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
- dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
+ dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
It also can be a sequence of dtypes, each element corresponds to a key in ``keys``.
dst_keys: the key of the corresponding ``dst_affine`` in the metadata dictionary.
allow_missing_keys: don't raise exception if key is missing.
@@ -248,20 +251,26 @@ def __init__(
Args:
keys: keys of the corresponding items to be transformed.
key_dst: key of image to resample to match.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
- See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
- See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
- dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
+ dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
It also can be a sequence of dtypes, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
"""
@@ -322,8 +331,12 @@ def __init__(
padding_mode: SequenceStr = GridSamplePadMode.BORDER,
align_corners: Union[Sequence[bool], bool] = False,
dtype: Union[Sequence[DtypeLike], DtypeLike] = np.float64,
+ scale_extent: bool = False,
+ recompute_affine: bool = False,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
+ min_pixdim: Union[Sequence[float], float, None] = None,
+ max_pixdim: Union[Sequence[float], float, None] = None,
allow_missing_keys: bool = False,
) -> None:
"""
@@ -347,39 +360,66 @@ def __init__(
translations components from the original affine will be
preserved in the target affine. This option will not flip/swap
axes against the original ones.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
- dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
+ dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
It also can be a sequence of dtypes, each element corresponds to a key in ``keys``.
+ scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,
+ default False. The option is ignored if output spatial size is specified when calling this transform.
+ See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`
+ should be `True` because `compute_shape_offset` already provides the corner alignment shift/scaling.
+ recompute_affine: whether to recompute affine based on the output shape. The affine computed
+ analytically does not reflect the potential quantization errors in terms of the output shape.
+ Set this flag to True to recompute the output affine based on the actual pixdim. Default to ``False``.
+ min_pixdim: minimal input spacing to be resampled. If provided, input image with a larger spacing than this
+ value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
+ value of `pixdim`. Default to `None`.
+ max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this
+ value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
+ value of `pixdim`. Default to `None`.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
- self.spacing_transform = Spacing(pixdim, diagonal=diagonal)
+ self.spacing_transform = Spacing(
+ pixdim, diagonal=diagonal, recompute_affine=recompute_affine, min_pixdim=min_pixdim, max_pixdim=max_pixdim
+ )
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys))
self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
self.dtype = ensure_tuple_rep(dtype, len(self.keys))
+ self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys))
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d: Dict = dict(data)
- for key, mode, padding_mode, align_corners, dtype in self.key_iterator(
- d, self.mode, self.padding_mode, self.align_corners, self.dtype
+ for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator(
+ d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent
):
# resample array of each corresponding key
d[key] = self.spacing_transform(
- data_array=d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype
+ data_array=d[key],
+ mode=mode,
+ padding_mode=padding_mode,
+ align_corners=align_corners,
+ dtype=dtype,
+ scale_extent=scale_extent,
)
return d
@@ -569,6 +609,15 @@ class Resized(MapTransform, InvertibleTransform):
'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
It also can be a sequence of bool or None, each element corresponds to a key in ``keys``.
+ anti_aliasing: bool
+ Whether to apply a Gaussian filter to smooth the image prior
+ to downsampling. It is crucial to filter when downsampling
+ the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
+ anti_aliasing_sigma: {float, tuple of floats}, optional
+ Standard deviation for Gaussian filtering used when anti-aliasing.
+ By default, this value is chosen as (s - 1) / 2 where s is the
+ downsampling factor, where s > 1. For the up-size case, s < 1, no
+ anti-aliasing is performed prior to rescaling.
allow_missing_keys: don't raise exception if key is missing.
"""
@@ -581,17 +630,29 @@ def __init__(
size_mode: str = "all",
mode: SequenceStr = InterpolateMode.AREA,
align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
+ anti_aliasing: Union[Sequence[bool], bool] = False,
+ anti_aliasing_sigma: Union[Sequence[Union[Sequence[float], float, None]], Sequence[float], float, None] = None,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
+ self.anti_aliasing = ensure_tuple_rep(anti_aliasing, len(self.keys))
+ self.anti_aliasing_sigma = ensure_tuple_rep(anti_aliasing_sigma, len(self.keys))
self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode)
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
- for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners):
- d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners)
+ for key, mode, align_corners, anti_aliasing, anti_aliasing_sigma in self.key_iterator(
+ d, self.mode, self.align_corners, self.anti_aliasing, self.anti_aliasing_sigma
+ ):
+ d[key] = self.resizer(
+ d[key],
+ mode=mode,
+ align_corners=align_corners,
+ anti_aliasing=anti_aliasing,
+ anti_aliasing_sigma=anti_aliasing_sigma,
+ )
return d
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
@@ -608,7 +669,6 @@ class Affined(MapTransform, InvertibleTransform):
backend = Affine.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
keys: KeysCollection,
@@ -620,7 +680,6 @@ def __init__(
spatial_size: Optional[Union[Sequence[int], int]] = None,
mode: SequenceStr = GridSampleMode.BILINEAR,
padding_mode: SequenceStr = GridSamplePadMode.REFLECTION,
- as_tensor_output: bool = True,
device: Optional[torch.device] = None,
dtype: Union[DtypeLike, torch.dtype] = np.float32,
allow_missing_keys: bool = False,
@@ -653,16 +712,22 @@ def __init__(
if some components of the `spatial_size` are non-positive values, the transform will use the
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"reflection"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
device: device on which the tensor will be allocated.
- dtype: data type for resampling computation. Defaults to ``np.float32``.
+ dtype: data type for resampling computation. Defaults to ``float32``.
If ``None``, use the data type of input data. To be compatible with other modules,
the output data type is always `float32`.
allow_missing_keys: don't raise exception if key is missing.
@@ -671,9 +736,6 @@ def __init__(
- :py:class:`monai.transforms.compose.MapTransform`
- :py:class:`RandAffineGrid` for the random affine parameters configurations.
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
MapTransform.__init__(self, keys, allow_missing_keys)
self.affine = Affine(
@@ -709,7 +771,6 @@ class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform):
backend = RandAffine.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
keys: KeysCollection,
@@ -722,7 +783,6 @@ def __init__(
mode: SequenceStr = GridSampleMode.BILINEAR,
padding_mode: SequenceStr = GridSamplePadMode.REFLECTION,
cache_grid: bool = False,
- as_tensor_output: bool = True,
device: Optional[torch.device] = None,
allow_missing_keys: bool = False,
) -> None:
@@ -759,14 +819,20 @@ def __init__(
scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select
the scale factor to translate for every spatial dims. A value of 1.0 is added to the result.
This allows 0 to correspond to no change (i.e., a scaling of 1.0).
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"reflection"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
cache_grid: whether to cache the identity sampling grid.
If the spatial size is not dynamically defined by input image, enabling this option could
accelerate the transform.
@@ -777,9 +843,6 @@ def __init__(
- :py:class:`monai.transforms.compose.MapTransform`
- :py:class:`RandAffineGrid` for the random affine parameters configurations.
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
@@ -805,8 +868,8 @@ def set_random_state(
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
out: Dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
@@ -814,7 +877,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
# all the keys share the same random Affine factor
self.rand_affine.randomize()
- spatial_size = d[first_key].shape[1:] # type: ignore
+ spatial_size = d[first_key].shape[1:]
+
sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size)
# change image size or do random transform
do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size))
@@ -828,7 +892,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
# do the transform
if do_resampling:
- d[key] = self.rand_affine(d[key], mode=mode, padding_mode=padding_mode, grid=grid)
+ d[key] = self.rand_affine(d[key], mode=mode, padding_mode=padding_mode, grid=grid) # type: ignore
+ else:
+ d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
if get_track_meta():
xform = self.pop_transform(d[key], check=False) if do_resampling else {}
self.push_transform(d[key], extra_info={"do_resampling": do_resampling, "rand_affine_info": xform})
@@ -841,7 +907,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
do_resampling = tr[TraceKeys.EXTRA_INFO]["do_resampling"]
if do_resampling:
d[key].applied_operations.append(tr[TraceKeys.EXTRA_INFO]["rand_affine_info"]) # type: ignore
- d[key] = self.rand_affine.inverse(d[key])
+ d[key] = self.rand_affine.inverse(d[key]) # type: ignore
return d
@@ -853,7 +919,6 @@ class Rand2DElasticd(RandomizableTransform, MapTransform):
backend = Rand2DElastic.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
keys: KeysCollection,
@@ -867,7 +932,6 @@ def __init__(
scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None,
mode: SequenceStr = GridSampleMode.BILINEAR,
padding_mode: SequenceStr = GridSamplePadMode.REFLECTION,
- as_tensor_output: bool = False,
device: Optional[torch.device] = None,
allow_missing_keys: bool = False,
) -> None:
@@ -906,14 +970,20 @@ def __init__(
scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select
the scale factor to translate for every spatial dims. A value of 1.0 is added to the result.
This allows 0 to correspond to no change (i.e., a scaling of 1.0).
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"reflection"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
device: device on which the tensor will be allocated.
allow_missing_keys: don't raise exception if key is missing.
@@ -921,9 +991,6 @@ def __init__(
- :py:class:`RandAffineGrid` for the random affine parameters configurations.
- :py:class:`Affine` for the affine transformation parameters configurations.
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
@@ -950,14 +1017,19 @@ def set_random_state(
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+
+ if first_key == ():
out: Dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
self.randomize(None)
+ device = self.rand_2d_elastic.device
+ if device is None and isinstance(d[first_key], torch.Tensor):
+ device = d[first_key].device # type: ignore
+ self.rand_2d_elastic.set_device(device)
+ sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first_key].shape[1:])
- sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore
# all the keys share the same random elastic factor
self.rand_2d_elastic.randomize(sp_size)
@@ -973,11 +1045,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
)
grid = CenterSpatialCrop(roi_size=sp_size)(grid[0])
else:
- _device = self.rand_2d_elastic.deform_grid.device
- grid = create_grid(spatial_size=sp_size, device=_device, backend="torch")
+ grid = create_grid(spatial_size=sp_size, device=device, backend="torch")
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
- d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)
+ d[key] = self.rand_2d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) # type: ignore
return d
@@ -988,7 +1059,6 @@ class Rand3DElasticd(RandomizableTransform, MapTransform):
backend = Rand3DElastic.backend
- @deprecated_arg(name="as_tensor_output", since="0.6")
def __init__(
self,
keys: KeysCollection,
@@ -1002,7 +1072,6 @@ def __init__(
scale_range: Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] = None,
mode: SequenceStr = GridSampleMode.BILINEAR,
padding_mode: SequenceStr = GridSamplePadMode.REFLECTION,
- as_tensor_output: bool = False,
device: Optional[torch.device] = None,
allow_missing_keys: bool = False,
) -> None:
@@ -1043,14 +1112,20 @@ def __init__(
scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select
the scale factor to translate for every spatial dims. A value of 1.0 is added to the result.
This allows 0 to correspond to no change (i.e., a scaling of 1.0).
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"reflection"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
device: device on which the tensor will be allocated.
allow_missing_keys: don't raise exception if key is missing.
@@ -1058,9 +1133,6 @@ def __init__(
- :py:class:`RandAffineGrid` for the random affine parameters configurations.
- :py:class:`Affine` for the affine transformation parameters configurations.
- .. deprecated:: 0.6.0
- ``as_tensor_output`` is deprecated.
-
"""
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, prob)
@@ -1087,28 +1159,32 @@ def set_random_state(
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+
+ if first_key == ():
out: Dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
self.randomize(None)
- sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore
+ sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first_key].shape[1:])
+
# all the keys share the same random elastic factor
self.rand_3d_elastic.randomize(sp_size)
- _device = self.rand_3d_elastic.device
- grid = create_grid(spatial_size=sp_size, device=_device, backend="torch")
+ device = self.rand_3d_elastic.device
+ if device is None and isinstance(d[first_key], torch.Tensor):
+ device = d[first_key].device
+ self.rand_3d_elastic.set_device(device)
+ grid = create_grid(spatial_size=sp_size, device=device, backend="torch")
if self._do_transform:
- device = self.rand_3d_elastic.device
gaussian = GaussianFilter(spatial_dims=3, sigma=self.rand_3d_elastic.sigma, truncated=3.0).to(device)
offset = torch.as_tensor(self.rand_3d_elastic.rand_offset, device=device).unsqueeze(0)
grid[:3] += gaussian(offset)[0] * self.rand_3d_elastic.magnitude
grid = self.rand_3d_elastic.rand_affine_grid(grid=grid)
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
- d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode)
+ d[key] = self.rand_3d_elastic.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) # type: ignore
return d
@@ -1237,14 +1313,15 @@ def set_random_state(
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
return d
self.randomize(None)
# all the keys share the same random selected axis
- self.flipper.randomize(d[first_key]) # type: ignore
+ self.flipper.randomize(d[first_key])
+
for key in self.key_iterator(d):
if self._do_transform:
d[key] = self.flipper(d[key], randomize=False)
@@ -1286,9 +1363,9 @@ class Rotated(MapTransform, InvertibleTransform):
align_corners: Defaults to False.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
- dtype: data type for resampling computation. Defaults to ``np.float32``.
+ dtype: data type for resampling computation. Defaults to ``float32``.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
"""
@@ -1359,9 +1436,9 @@ class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform):
align_corners: Defaults to False.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
It also can be a sequence of bool, each element corresponds to a key in ``keys``.
- dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
+ dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
- the output data type is always ``np.float32``.
+ the output data type is always ``float32``.
It also can be a sequence of dtype or None, each element corresponds to a key in ``keys``.
allow_missing_keys: don't raise exception if key is missing.
"""
@@ -1416,7 +1493,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
randomize=False,
)
else:
- d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
+ d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
if get_track_meta():
rot_info = self.pop_transform(d[key], check=False) if self._do_transform else {}
self.push_transform(d[key], extra_info=rot_info)
@@ -1567,15 +1644,16 @@ def set_random_state(
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
out: Dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
self.randomize(None)
# all the keys share the same random zoom factor
- self.rand_zoom.randomize(d[first_key]) # type: ignore
+ self.rand_zoom.randomize(d[first_key])
+
for key, mode, padding_mode, align_corners in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners
):
@@ -1584,7 +1662,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, randomize=False
)
else:
- d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
+ d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
if get_track_meta():
xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
self.push_transform(d[key], extra_info=xform)
@@ -1624,14 +1702,20 @@ def __init__(
distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the
corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.
Each value in the tuple represents the distort step of the related cell.
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
- Padding mode for outside grid values. Defaults to ``"reflection"``.
+ Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
device: device on which the tensor will be allocated.
allow_missing_keys: don't raise exception if key is missing.
@@ -1674,14 +1758,20 @@ def __init__(
distort_limit: range to randomly distort.
If single number, distort_limit is picked from (-distort_limit, distort_limit).
Defaults to (-0.03, 0.03).
- mode: {``"bilinear"``, ``"nearest"``}
+ mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
+ and the value represents the order of the spline interpolation.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
- Padding mode for outside grid values. Defaults to ``"reflection"``.
+ Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
- It also can be a sequence of string, each element corresponds to a key in ``keys``.
+ When `mode` is an integer, using numpy/cupy backends, this argument accepts
+ {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ It also can be a sequence, each element corresponds to a key in ``keys``.
device: device on which the tensor will be allocated.
allow_missing_keys: don't raise exception if key is missing.
@@ -1708,12 +1798,13 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
out: Dict[Hashable, torch.Tensor] = convert_to_tensor(d, track_meta=get_track_meta())
return out
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
out = convert_to_tensor(d, track_meta=get_track_meta())
return out
- self.rand_grid_distortion.randomize(d[first_key].shape[1:]) # type: ignore
+ self.rand_grid_distortion.randomize(d[first_key].shape[1:])
+
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False)
return d
@@ -1817,25 +1908,11 @@ def __init__(
**pad_kwargs,
)
- def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]:
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
- original_spatial_shape = d[first(self.keys)].shape[1:]
- output = []
- results = [self.patcher(d[key]) for key in self.keys]
- num_patches = min(len(r) for r in results)
- for patch in zip(*results):
- new_dict = {k: v[0] for k, v in zip(self.keys, patch)}
- # fill in the extra keys with unmodified data
- for k in set(d.keys()).difference(set(self.keys)):
- new_dict[k] = deepcopy(d[k])
- # fill additional metadata
- new_dict["original_spatial_shape"] = original_spatial_shape
- new_dict[WSIPatchKeys.LOCATION] = patch[0][1] # use the starting coordinate of the first item
- new_dict[WSIPatchKeys.SIZE] = self.patcher.patch_size
- new_dict[WSIPatchKeys.COUNT] = num_patches
- new_dict["offset"] = self.patcher.offset
- output.append(new_dict)
- return output
+ for key in self.key_iterator(d):
+ d[key] = self.patcher(d[key])
+ return d
class RandGridPatchd(RandomizableTransform, MapTransform):
@@ -1908,31 +1985,15 @@ def set_random_state(
self.patcher.set_random_state(seed, state)
return self
- def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict]:
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
- original_spatial_shape = d[first(self.keys)].shape[1:]
- # all the keys share the same random noise
- first_key: Union[Hashable, List] = self.first_key(d)
- if first_key == []:
- return [d]
- self.patcher.randomize(d[first_key]) # type: ignore
- results = [self.patcher(d[key], randomize=False) for key in self.keys]
-
- num_patches = min(len(r) for r in results)
- output = []
- for patch in zip(*results):
- new_dict = {k: v[0] for k, v in zip(self.keys, patch)}
- # fill in the extra keys with unmodified data
- for k in set(d.keys()).difference(set(self.keys)):
- new_dict[k] = deepcopy(d[k])
- # fill additional metadata
- new_dict["original_spatial_shape"] = original_spatial_shape
- new_dict[WSIPatchKeys.LOCATION] = patch[0][1] # use the starting coordinate of the first item
- new_dict[WSIPatchKeys.SIZE] = self.patcher.patch_size
- new_dict[WSIPatchKeys.COUNT] = num_patches
- new_dict["offset"] = self.patcher.offset
- output.append(new_dict)
- return output
+ # All the keys share the same random noise
+ for key in self.key_iterator(d):
+ self.patcher.randomize(d[key])
+ break
+ for key in self.key_iterator(d):
+ d[key] = self.patcher(d[key], randomize=False)
+ return d
SpatialResampleD = SpatialResampleDict = SpatialResampled
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 730cb634c0a..b1a7d9b4db4 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -24,8 +24,20 @@
from monai.data.meta_tensor import MetaTensor
from monai.utils import MAX_SEED, ensure_tuple, first
from monai.utils.enums import TransformBackends
-
-__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
+from monai.utils.misc import MONAIEnvVars
+
+__all__ = [
+ "ThreadUnsafe",
+ "apply_transform",
+ "LazyTrait",
+ "RandomizableTrait",
+ "MultiSampleTrait",
+ "Randomizable",
+ "LazyTransform",
+ "RandomizableTransform",
+ "Transform",
+ "MapTransform",
+]
ReturnType = TypeVar("ReturnType")
@@ -89,7 +101,10 @@ def apply_transform(
return [_apply_transform(transform, item, unpack_items) for item in data]
return _apply_transform(transform, data, unpack_items)
except Exception as e:
-
+ # if in debug mode, don't swallow exception so that the breakpoint
+ # appears where the exception was raised.
+ if MONAIEnvVars.debug():
+ raise
if log_stats and not isinstance(transform, transforms.compose.Compose):
# log the input data information of exact transform in the transform chain
datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False)
@@ -114,6 +129,56 @@ def _log_stats(data, prefix: Optional[str] = "Data"):
raise RuntimeError(f"applying transform {transform}") from e
+class LazyTrait:
+ """
+ An interface to indicate that the transform has the capability to execute using
+ MONAI's lazy resampling feature. In order to do this, the implementing class needs
+ to be able to describe its operation as an affine matrix or grid with accompanying metadata.
+ This interface can be extended from by people adapting transforms to the MONAI framework as
+ well as by implementors of MONAI transforms.
+ """
+
+ @property
+ def lazy_evaluation(self):
+ """
+ Get whether lazy_evaluation is enabled for this transform instance.
+ Returns:
+ True if the transform is operating in a lazy fashion, False if not.
+ """
+ raise NotImplementedError()
+
+ @lazy_evaluation.setter
+ def lazy_evaluation(self, enabled: bool):
+ """
+ Set whether lazy_evaluation is enabled for this transform instance.
+ Args:
+ enabled: True if the transform should operate in a lazy fashion, False if not.
+ """
+ raise NotImplementedError()
+
+
+class RandomizableTrait:
+ """
+ An interface to indicate that the transform has the capability to perform
+ randomized transforms to the data that it is called upon. This interface
+ can be extended from by people adapting transforms to the MONAI framework as well as by
+ implementors of MONAI transforms.
+ """
+
+ pass
+
+
+class MultiSampleTrait:
+ """
+ An interface to indicate that the transform has the capability to return multiple samples
+ given an input, such as when performing random crops of a sample. This interface can be
+ extended from by people adapting transforms to the MONAI framework as well as by implementors
+ of MONAI transforms.
+ """
+
+ pass
+
+
class ThreadUnsafe:
"""
A class to denote that the transform will mutate its member variables,
@@ -127,7 +192,7 @@ class ThreadUnsafe:
pass
-class Randomizable(ABC, ThreadUnsafe):
+class Randomizable(ThreadUnsafe):
"""
An interface for handling random state locally, currently based on a class
variable `R`, which is an instance of `np.random.RandomState`. This
@@ -211,9 +276,13 @@ class Transform(ABC):
:py:class:`monai.transforms.Compose`
"""
- # Transforms should add data types to this list if they are capable of performing a transform without
- # modifying the input type. For example, ["torch.Tensor", "np.ndarray"] means that no copies of the data
- # are required if the input is either `torch.Tensor` or `np.ndarray`.
+ # Transforms should add `monai.transforms.utils.TransformBackends` to this list if they are performing
+ # the data processing using the corresponding backend APIs.
+ # Most of MONAI transform's inputs and outputs will be converted into torch.Tensor or monai.data.MetaTensor.
+ # This variable provides information about whether the input will be converted
+ # to other data types during the transformation. Note that not all `dtype` (such as float32, uint8) are supported
+ # by all the data types, the `dtype` during the conversion is determined automatically by each transform,
+ # please refer to the transform's docstring.
backend: List[TransformBackends] = []
@abstractmethod
@@ -243,7 +312,27 @@ def __call__(self, data: Any):
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
-class RandomizableTransform(Randomizable, Transform):
+class LazyTransform(Transform, LazyTrait):
+ """
+ An implementation of functionality for lazy transforms that can be subclassed by array and
+ dictionary transforms to simplify implementation of new lazy transforms.
+ """
+
+ def __init__(self, lazy_evaluation: Optional[bool] = True):
+ self.lazy_evaluation = lazy_evaluation
+
+ @property
+ def lazy_evaluation(self):
+ return self.lazy_evaluation
+
+ @lazy_evaluation.setter
+ def lazy_evaluation(self, lazy_evaluation: bool):
+ if not isinstance(lazy_evaluation, bool):
+ raise TypeError("'lazy_evaluation must be a bool but is of " f"type {type(lazy_evaluation)}'")
+ self.lazy_evaluation = lazy_evaluation
+
+
+class RandomizableTransform(Randomizable, Transform, RandomizableTrait):
"""
An interface for handling random state locally, currently based on a class variable `R`,
which is an instance of `np.random.RandomState`.
@@ -407,10 +496,10 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Optional[
def first_key(self, data: Dict[Hashable, Any]):
"""
Get the first available key of `self.keys` in the input `data` dictionary.
- If no available key, return an empty list `[]`.
+ If no available key, return an empty tuple `()`.
Args:
data: data that the transform will be applied to.
"""
- return first(self.key_iterator(data), [])
+ return first(self.key_iterator(data), ())
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index e001f101ce1..7a1f86c1e0a 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -58,7 +58,6 @@
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
cp, has_cp = optional_import("cupy")
-
__all__ = [
"Identity",
"AsChannelFirst",
@@ -201,44 +200,64 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
class EnsureChannelFirst(Transform):
"""
- Automatically adjust or add the channel dimension of input data to ensure `channel_first` shape.
- It extracts the `original_channel_dim` info from provided meta_data dictionary.
- Typical values of `original_channel_dim` can be: "no_channel", 0, -1.
- Convert the data to `channel_first` based on the `original_channel_dim` information.
+ Adjust or add the channel dimension of input data to ensure `channel_first` shape.
+
+ This extracts the `original_channel_dim` info from provided meta_data dictionary or MetaTensor input. This value
+ should state which dimension is the channel dimension so that it can be moved forward, or contain "no_channel" to
+ state no dimension is the channel and so a 1-size first dimension is to be added.
+
+ Args:
+ strict_check: whether to raise an error when the meta information is insufficient.
+ channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array.
+ It overrides the `original_channel_dim` from provided MetaTensor input.
+ If the input array doesn't have a channel dim, this value should be ``'no_channel'``.
+ If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension.
"""
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
- def __init__(self, strict_check: bool = True):
- """
- Args:
- strict_check: whether to raise an error when the meta information is insufficient.
- """
+ def __init__(self, strict_check: bool = True, channel_dim: Union[None, str, int] = None):
self.strict_check = strict_check
+ self.input_channel_dim = channel_dim
def __call__(self, img: torch.Tensor, meta_dict: Optional[Mapping] = None) -> torch.Tensor:
"""
Apply the transform to `img`.
"""
if not isinstance(img, MetaTensor) and not isinstance(meta_dict, Mapping):
- msg = "metadata not available, EnsureChannelFirst is not in use."
- if self.strict_check:
- raise ValueError(msg)
- warnings.warn(msg)
- return img
+ if self.input_channel_dim is None:
+ msg = "Metadata not available and channel_dim=None, EnsureChannelFirst is not in use."
+ if self.strict_check:
+ raise ValueError(msg)
+ warnings.warn(msg)
+ return img
+ else:
+ img = MetaTensor(img)
+
if isinstance(img, MetaTensor):
meta_dict = img.meta
- channel_dim = meta_dict.get("original_channel_dim") # type: ignore
+
+ channel_dim = meta_dict.get("original_channel_dim", None) if isinstance(meta_dict, Mapping) else None
+ if self.input_channel_dim is not None:
+ channel_dim = self.input_channel_dim
if channel_dim is None:
- msg = "Unknown original_channel_dim in the meta_dict, EnsureChannelFirst is not in use."
+ msg = "Unknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`."
if self.strict_check:
raise ValueError(msg)
warnings.warn(msg)
return img
+
+ # track the original channel dim
+ if isinstance(meta_dict, dict):
+ meta_dict["original_channel_dim"] = channel_dim
+
if channel_dim == "no_channel":
- return convert_to_tensor(img[None], track_meta=get_track_meta()) # type: ignore
- return convert_to_tensor(moveaxis(img, channel_dim, 0), track_meta=get_track_meta()) # type: ignore
+ result = img[None]
+ else:
+ result = moveaxis(img, channel_dim, 0) # type: ignore
+
+ return convert_to_tensor(result, track_meta=get_track_meta()) # type: ignore
class RepeatChannel(Transform):
@@ -251,7 +270,7 @@ class RepeatChannel(Transform):
repeats: the number of repetitions for each element.
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.TORCH]
def __init__(self, repeats: int) -> None:
if repeats <= 0:
@@ -401,19 +420,19 @@ class ToTensor(Transform):
device: target device to put the converted Tensor data.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
- track_meta: whether to convert to `MetaTensor`, default to `False`, output type will be `torch.Tensor`.
- if `None`, use the return value of ``get_track_meta``.
+ track_meta: whether to convert to `MetaTensor` or regular tensor, default to `None`,
+ use the return value of ``get_track_meta``.
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.TORCH]
def __init__(
self,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
wrap_sequence: bool = True,
- track_meta: Optional[bool] = False,
+ track_meta: Optional[bool] = None,
) -> None:
super().__init__()
self.dtype = dtype
@@ -445,8 +464,8 @@ class EnsureType(Transform):
device: for Tensor data type, specify the target device.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
- track_meta: whether to convert to `MetaTensor` when `data_type` is "tensor".
- If False, the output data type will be `torch.Tensor`. Default to the return value of ``get_track_meta``.
+ track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``,
+ if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`.
"""
@@ -501,7 +520,7 @@ class ToNumpy(Transform):
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.NUMPY]
def __init__(self, dtype: DtypeLike = None, wrap_sequence: bool = True) -> None:
super().__init__()
@@ -528,7 +547,7 @@ class ToCupy(Transform):
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.CUPY]
def __init__(self, dtype: Optional[np.dtype] = None, wrap_sequence: bool = True) -> None:
super().__init__()
@@ -547,7 +566,7 @@ class ToPIL(Transform):
Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.NUMPY]
def __call__(self, img):
"""
@@ -565,7 +584,7 @@ class Transpose(Transform):
Transposes the input image based on the given `indices` dimension ordering.
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.TORCH]
def __init__(self, indices: Optional[Sequence[int]]) -> None:
self.indices = None if indices is None else tuple(indices)
@@ -585,11 +604,12 @@ class SqueezeDim(Transform):
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
- def __init__(self, dim: Optional[int] = 0) -> None:
+ def __init__(self, dim: Optional[int] = 0, update_meta=True) -> None:
"""
Args:
dim: dimension to be squeezed. Default = 0
"None" works when the input is numpy array.
+ update_meta: whether to update the meta info if the input is a metatensor. Default is ``True``.
Raises:
TypeError: When ``dim`` is not an ``Optional[int]``.
@@ -598,6 +618,7 @@ def __init__(self, dim: Optional[int] = 0) -> None:
if dim is not None and not isinstance(dim, int):
raise TypeError(f"dim must be None or a int but is {type(dim).__name__}.")
self.dim = dim
+ self.update_meta = update_meta
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
@@ -606,11 +627,25 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
if self.dim is None:
+ if self.update_meta:
+ warnings.warn("update_meta=True is ignored when dim=None.")
return img.squeeze()
+ dim = (self.dim + len(img.shape)) if self.dim < 0 else self.dim
# for pytorch/numpy unification
- if img.shape[self.dim] != 1:
- raise ValueError(f"Can only squeeze singleton dimension, got shape {img.shape}.")
- return img.squeeze(self.dim)
+ if img.shape[dim] != 1:
+ raise ValueError(f"Can only squeeze singleton dimension, got shape {img.shape[dim]} of {img.shape}.")
+ img = img.squeeze(dim)
+ if self.update_meta and isinstance(img, MetaTensor) and dim > 0 and len(img.affine.shape) == 2:
+ h, w = img.affine.shape
+ affine, device = img.affine, img.affine.device if isinstance(img.affine, torch.Tensor) else None
+ if h > dim:
+ affine = affine[torch.arange(0, h, device=device) != dim - 1]
+ if w > dim:
+ affine = affine[:, torch.arange(0, w, device=device) != dim - 1]
+ if (affine.shape[0] == affine.shape[1]) and not np.linalg.det(convert_to_numpy(affine, wrap_sequence=True)):
+ warnings.warn(f"After SqueezeDim, img.affine is ill-posed: \n{img.affine}.")
+ img.affine = affine
+ return img
class DataStats(Transform):
@@ -1017,7 +1052,7 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform):
"""
Convert labels to multi channels based on brats18 classes:
label 1 is the necrotic and non-enhancing tumor core
- label 2 is the the peritumoral edema
+ label 2 is the peritumoral edema
label 4 is the GD-enhancing tumor
The possible classes are TC (Tumor core), WT (Whole tumor)
and ET (Enhancing tumor).
@@ -1056,7 +1091,7 @@ class AddExtremePointsChannel(Randomizable, Transform):
ValueError: When label image is not single channel.
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.TORCH]
def __init__(self, background: int = 0, pert: float = 0.0) -> None:
self._background = background
@@ -1388,7 +1423,7 @@ class AddCoordinateChannels(Transform):
"""
- backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+ backend = [TransformBackends.NUMPY]
@deprecated_arg(
name="spatial_channels", new_name="spatial_dims", since="0.8", msg_suffix="please use `spatial_dims` instead."
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index 2c4394a7da4..d45c4431e18 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -301,6 +301,7 @@ def __init__(
meta_key_postfix: str = DEFAULT_POST_FIX,
strict_check: bool = True,
allow_missing_keys: bool = False,
+ channel_dim=None,
) -> None:
"""
Args:
@@ -308,9 +309,13 @@ def __init__(
See also: :py:class:`monai.transforms.compose.MapTransform`
strict_check: whether to raise an error when the meta information is insufficient.
allow_missing_keys: don't raise exception if key is missing.
+ channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array.
+ It overrides the `original_channel_dim` from provided MetaTensor input.
+ If the input array doesn't have a channel dim, this value should be ``'no_channel'``.
+ If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension.
"""
super().__init__(keys, allow_missing_keys)
- self.adjuster = EnsureChannelFirst(strict_check=strict_check)
+ self.adjuster = EnsureChannelFirst(strict_check=strict_check, channel_dim=channel_dim)
self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
@@ -372,6 +377,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
class SplitDimd(MapTransform):
+
+ backend = SplitDim.backend
+
def __init__(
self,
keys: KeysCollection,
@@ -379,6 +387,7 @@ def __init__(
dim: int = 0,
keepdim: bool = True,
update_meta: bool = True,
+ list_output: bool = False,
allow_missing_keys: bool = False,
) -> None:
"""
@@ -394,15 +403,34 @@ def __init__(
dimension will be squeezed.
update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to
reflect the cropped image
+ list_output: it `True`, the output will be a list of dictionaries with the same keys as original.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.output_postfixes = output_postfixes
self.splitter = SplitDim(dim, keepdim, update_meta)
+ self.list_output = list_output
+ if self.list_output is None and self.output_postfixes is not None:
+ raise ValueError("`output_postfixes` should not be provided when `list_output` is `True`.")
- def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
+ def __call__(
+ self, data: Mapping[Hashable, torch.Tensor]
+ ) -> Union[Dict[Hashable, torch.Tensor], List[Dict[Hashable, torch.Tensor]]]:
d = dict(data)
- for key in self.key_iterator(d):
+ all_keys = list(set(self.key_iterator(d)))
+
+ if self.list_output:
+ output = []
+ results = [self.splitter(d[key]) for key in all_keys]
+ for row in zip(*results):
+ new_dict = dict(zip(all_keys, row))
+ # fill in the extra keys with unmodified data
+ for k in set(d.keys()).difference(set(all_keys)):
+ new_dict[k] = deepcopy(d[k])
+ output.append(new_dict)
+ return output
+
+ for key in all_keys:
rets = self.splitter(d[key])
postfixes: Sequence = list(range(len(rets))) if self.output_postfixes is None else self.output_postfixes
if len(postfixes) != len(rets):
@@ -486,7 +514,7 @@ def __init__(
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
wrap_sequence: bool = True,
- track_meta: Optional[bool] = False,
+ track_meta: Optional[bool] = None,
allow_missing_keys: bool = False,
) -> None:
"""
@@ -497,8 +525,8 @@ def __init__(
device: specify the target device to put the Tensor data.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
- track_meta: whether to convert to `MetaTensor`, default to `False`, output type will be `torch.Tensor`.
- if `None`, use the return value of ``get_track_meta``.
+ track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``,
+ if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`.
allow_missing_keys: don't raise exception if key is missing.
"""
@@ -508,19 +536,19 @@ def __init__(
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
- self.push_transform(d, key)
d[key] = self.converter(d[key])
+ self.push_transform(d, key)
return d
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
+ # Remove the applied transform
+ self.pop_transform(d, key)
# Create inverse transform
inverse_transform = ToNumpy()
# Apply inverse
d[key] = inverse_transform(d[key])
- # Remove the applied transform
- self.pop_transform(d, key)
return d
@@ -758,16 +786,19 @@ class SqueezeDimd(MapTransform):
backend = SqueezeDim.backend
- def __init__(self, keys: KeysCollection, dim: int = 0, allow_missing_keys: bool = False) -> None:
+ def __init__(
+ self, keys: KeysCollection, dim: int = 0, update_meta: bool = True, allow_missing_keys: bool = False
+ ) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
dim: dimension to be squeezed. Default: 0 (the first dimension)
+ update_meta: whether to update the meta info if the input is a metatensor. Default is ``True``.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
- self.converter = SqueezeDim(dim=dim)
+ self.converter = SqueezeDim(dim=dim, update_meta=update_meta)
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index ae550e7ce66..e96d906f20e 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -46,7 +46,6 @@
PostFix,
PytorchPadMode,
TraceKeys,
- deprecated_arg,
ensure_tuple,
ensure_tuple_rep,
ensure_tuple_size,
@@ -58,9 +57,10 @@
optional_import,
)
from monai.utils.enums import TransformBackends
-from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
+from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor
-measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
+measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
+morphology, has_morphology = optional_import("skimage.morphology")
ndimage, _ = optional_import("scipy.ndimage")
cp, has_cp = optional_import("cupy")
cp_ndarray, _ = optional_import("cupy", name="ndarray")
@@ -69,6 +69,7 @@
__all__ = [
"allow_missing_keys_mode",
+ "check_boundaries",
"compute_divisible_spatial_size",
"convert_applied_interp_mode",
"copypaste_arrays",
@@ -86,6 +87,7 @@
"generate_spatial_bounding_box",
"get_extreme_points",
"get_largest_connected_component_mask",
+ "remove_small_objects",
"img_bounds",
"in_bounds",
"is_empty",
@@ -503,11 +505,11 @@ def generate_pos_neg_label_crop_centers(
raise ValueError("No sampling location available.")
if len(fg_indices) == 0 or len(bg_indices) == 0:
+ pos_ratio = 0 if len(fg_indices) == 0 else 1
warnings.warn(
- f"N foreground {len(fg_indices)}, N background {len(bg_indices)},"
- "unable to generate class balanced samples."
+ f"Num foregrounds {len(fg_indices)}, Num backgrounds {len(bg_indices)}, "
+ f"unable to generate class balanced samples, setting `pos_ratio` to {pos_ratio}."
)
- pos_ratio = 0 if fg_indices.size == 0 else 1
for _ in range(num_samples):
indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices
@@ -586,6 +588,9 @@ def create_grid(
"""
compute a `spatial_size` mesh.
+ - when ``homogeneous=True``, the output shape is (N+1, dim_size_1, dim_size_2, ..., dim_size_N)
+ - when ``homogeneous=False``, the output shape is (N, dim_size_1, dim_size_2, ..., dim_size_N)
+
Args:
spatial_size: spatial size of the grid.
spacing: same len as ``spatial_size``, defaults to 1.0 (dense grid).
@@ -946,7 +951,9 @@ def generate_spatial_bounding_box(
return box_start, box_end
-def get_largest_connected_component_mask(img: NdarrayTensor, connectivity: Optional[int] = None) -> NdarrayTensor:
+def get_largest_connected_component_mask(
+ img: NdarrayTensor, connectivity: Optional[int] = None, num_components: int = 1
+) -> NdarrayTensor:
"""
Gets the largest connected component mask of an image.
@@ -956,24 +963,87 @@ def get_largest_connected_component_mask(img: NdarrayTensor, connectivity: Optio
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
connectivity of ``input.ndim`` is used. for more details:
https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
+ num_components: The number of largest components to preserve.
+ """
+ # use skimage/cucim.skimage and np/cp depending on whether packages are
+ # available and input is non-cpu torch.tensor
+ use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device != torch.device("cpu")
+ if use_cp:
+ img_ = convert_to_cupy(img.short()) # type: ignore
+ label = cucim.skimage.measure.label
+ lib = cp
+ else:
+ if not has_measure:
+ raise RuntimeError("Skimage.measure required.")
+ img_, *_ = convert_data_type(img, np.ndarray)
+ label = measure.label
+ lib = np
+
+ # features will be an image -- 0 for background and then each different
+ # feature will have its own index.
+ features, num_features = label(img_, connectivity=connectivity, return_num=True)
+ # if num features less than max desired, nothing to do.
+ if num_features <= num_components:
+ out = img_.astype(bool)
+ else:
+ # ignore background
+ nonzeros = features[lib.nonzero(features)]
+ # get number voxels per feature (bincount). argsort[::-1] to get indices
+ # of largest components.
+ features_to_keep = lib.argsort(lib.bincount(nonzeros))[::-1]
+ # only keep the first n non-background indices
+ features_to_keep = features_to_keep[:num_components]
+ # generate labelfield. True if in list of features to keep
+ out = lib.isin(features, features_to_keep)
+
+ return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]
+
+
+def remove_small_objects(
+ img: NdarrayTensor, min_size: int = 64, connectivity: int = 1, independent_channels: bool = True
+) -> NdarrayTensor:
"""
- if isinstance(img, torch.Tensor) and has_cp and has_cucim:
- x_cupy = monai.transforms.ToCupy()(img.short())
- x_label = cucim.skimage.measure.label(x_cupy, connectivity=connectivity)
- vals, counts = cp.unique(x_label[cp.nonzero(x_label)], return_counts=True)
- comp = x_label == vals[cp.ndarray.argmax(counts)]
- out_tensor = monai.transforms.ToTensor(device=img.device)(comp)
- out_tensor = out_tensor.bool()
+ Use `skimage.morphology.remove_small_objects` to remove small objects from images.
+ See: https://scikit-image.org/docs/dev/api/skimage.morphology.html#remove-small-objects.
- return out_tensor # type: ignore
+ Data should be one-hotted.
- img_arr = convert_data_type(img, np.ndarray)[0]
- largest_cc: np.ndarray = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype)
- img_arr = measure.label(img_arr, connectivity=connectivity)
- if img_arr.max() != 0:
- largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1)
+ Args:
+ img: image to process. Expected shape: C, H,W,[D]. Expected to only have singleton channel dimension,
+ i.e., not be one-hotted. Converted to type int.
+ min_size: objects smaller than this size are removed.
+ connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
+ Accepted values are ranging from 1 to input.ndim. If ``None``, a full
+ connectivity of ``input.ndim`` is used. For more details refer to linked scikit-image
+ documentation.
+ independent_channels: Whether to consider each channel independently.
+ """
+ # if all equal to one value, no need to call skimage
+ if len(unique(img)) == 1:
+ return img
- return convert_to_dst_type(largest_cc, dst=img, dtype=largest_cc.dtype)[0]
+ if not has_morphology:
+ raise RuntimeError("Skimage required.")
+
+ img_np: np.ndarray
+ img_np, *_ = convert_data_type(img, np.ndarray)
+
+ # morphology.remove_small_objects assumes them to be independent by default
+ # else, convert to foreground vs background, remove small objects, then convert
+ # back by multiplying the output by the input
+ if not independent_channels:
+ img_np = img_np > 0
+ else:
+ # if binary, convert to boolean, else int
+ img_np = img_np.astype(bool if img_np.max() <= 1 else np.int32)
+
+ out_np = morphology.remove_small_objects(img_np, min_size, connectivity)
+ out, *_ = convert_to_dst_type(out_np, img)
+
+ # convert back by multiplying
+ if not independent_channels:
+ out = img * out # type: ignore
+ return out
def get_unique_labels(
@@ -1364,10 +1434,7 @@ class Fourier:
"""
@staticmethod
- @deprecated_arg(
- name="n_dims", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
- )
- def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = None) -> NdarrayOrTensor:
+ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
"""
Applies fourier transform and shifts the zero-frequency component to the
center of the spectrum. Only the spatial dimensions get transformed.
@@ -1376,14 +1443,9 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] =
x: Image to transform.
spatial_dims: Number of spatial dimensions.
- .. deprecated:: 0.6.0
- ``n_dims`` is deprecated, use ``spatial_dims`` instead.
-
Returns
k: K-space data.
"""
- if n_dims is not None:
- spatial_dims = n_dims
dims = tuple(range(-spatial_dims, 0))
k: NdarrayOrTensor
if isinstance(x, torch.Tensor):
@@ -1397,9 +1459,6 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] =
return k
@staticmethod
- @deprecated_arg(
- name="n_dims", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
- )
def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[int] = None) -> NdarrayOrTensor:
"""
Applies inverse shift and fourier transform. Only the spatial
@@ -1409,14 +1468,9 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: Optional[in
k: K-space data.
spatial_dims: Number of spatial dimensions.
- .. deprecated:: 0.6.0
- ``n_dims`` is deprecated, use ``spatial_dims`` instead.
-
Returns:
x: Tensor in image space.
"""
- if n_dims is not None:
- spatial_dims = n_dims
dims = tuple(range(-spatial_dims, 0))
out: NdarrayOrTensor
if isinstance(k, torch.Tensor):
@@ -1676,5 +1730,65 @@ def sync_meta_info(key, data_dict, t: bool = True):
return d
+def check_boundaries(boundaries) -> None:
+ """
+ Check boundaries for Signal transforms
+ """
+ if not (
+ isinstance(boundaries, Sequence) and len(boundaries) == 2 and all(isinstance(i, float) for i in boundaries)
+ ):
+ raise ValueError("Incompatible values: boundaries needs to be a list of float.")
+
+
+def paste_slices(tup):
+ """
+ given a tuple (pos,w,max_w), return a tuple of slices
+ """
+ pos, w, max_w = tup
+ max_w = max_w.shape[len(max_w.shape) - 1]
+ orig_min = max(pos, 0)
+ orig_max = min(pos + w, max_w)
+ block_min = -min(pos, 0)
+ block_max = max_w - max(pos + w, max_w)
+ block_max = block_max if block_max != 0 else None
+ return slice(orig_min, orig_max), slice(block_min, block_max)
+
+
+def paste(orig, block, loc):
+ """
+ given a location (loc) and an original array (orig), paste a block array into it
+ """
+ loc_zip = zip(loc, block.shape, orig)
+ orig_slices, block_slices = zip(*map(paste_slices, loc_zip))
+
+ orig[:, orig_slices[0]] = block[block_slices[0]]
+
+ if orig.shape[0] == 1:
+ orig = orig.squeeze()
+ return orig
+
+
+def squarepulse(sig, duty: float = 0.5):
+ """
+ compute squarepulse using pytorch
+ equivalent to numpy implementation from
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.square.html
+ """
+ t, w = convert_to_tensor(sig), convert_to_tensor(duty)
+ w = convert_to_tensor(w)
+ t = convert_to_tensor(t)
+
+ y = torch.zeros(t.shape)
+
+ mask1 = (w > 1) | (w < 0)
+
+ tmod = torch.remainder(t, 2 * torch.pi)
+ mask2 = (~mask1) & (tmod < w * 2 * torch.pi)
+ y[mask2] = 1
+ mask3 = (~mask1) & (~mask2)
+ y[mask3] = -1
+ return y
+
+
if __name__ == "__main__":
print_transform_backends()
diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py
index 554c3752d97..c0d45d55627 100644
--- a/monai/transforms/utils_create_transform_ims.py
+++ b/monai/transforms/utils_create_transform_ims.py
@@ -92,6 +92,7 @@
HistogramNormalize,
KSpaceSpikeNoise,
MaskIntensity,
+ MedianSmooth,
NormalizeIntensity,
RandAdjustContrast,
RandBiasField,
@@ -123,6 +124,7 @@
HistogramNormalized,
KSpaceSpikeNoised,
MaskIntensityd,
+ MedianSmoothD,
NormalizeIntensityd,
RandAdjustContrastd,
RandBiasFieldd,
@@ -145,8 +147,14 @@
StdShiftIntensityd,
ThresholdIntensityd,
)
-from monai.transforms.post.array import KeepLargestConnectedComponent, LabelFilter, LabelToContour
-from monai.transforms.post.dictionary import AsDiscreted, KeepLargestConnectedComponentd, LabelFilterd, LabelToContourd
+from monai.transforms.post.array import KeepLargestConnectedComponent, LabelFilter, LabelToContour, RemoveSmallObjects
+from monai.transforms.post.dictionary import (
+ AsDiscreted,
+ KeepLargestConnectedComponentd,
+ LabelFilterd,
+ LabelToContourd,
+ RemoveSmallObjectsd,
+)
from monai.transforms.smooth_field.array import (
RandSmoothDeform,
RandSmoothFieldAdjustContrast,
@@ -178,6 +186,7 @@
Spacingd,
)
from monai.utils.enums import CommonKeys
+from monai.utils.misc import MONAIEnvVars
from monai.utils.module import optional_import
if TYPE_CHECKING:
@@ -195,7 +204,7 @@ def get_data(keys):
Use MarsAtlas as it only contains 1 image for quick download and
that image is parcellated.
"""
- cache_dir = os.environ.get("MONAI_DATA_DIRECTORY") or tempfile.mkdtemp()
+ cache_dir = MONAIEnvVars.data_dir() or tempfile.mkdtemp()
fname = "MarsAtlas-MNI-Colin27.zip"
url = "https://www.dropbox.com/s/ndz8qtqblkciole/" + fname + "?dl=1"
out_path = os.path.join(cache_dir, "MarsAtlas-MNI-Colin27")
@@ -420,7 +429,7 @@ def create_transform_im(
seed = seed + 1 if isinstance(transform, MapTransform) else seed
transform.set_random_state(seed)
- out_dir = os.environ.get("MONAI_DOC_IMAGES")
+ out_dir = MONAIEnvVars.doc_images()
if out_dir is None:
raise RuntimeError(
"Please git clone https://github.com/Project-MONAI/DocImages"
@@ -465,7 +474,7 @@ def create_transform_im(
create_transform_im(RandFlipd, dict(keys=keys, prob=1, spatial_axis=2), data)
create_transform_im(Flip, dict(spatial_axis=1), data)
create_transform_im(Flipd, dict(keys=keys, spatial_axis=2), data)
- create_transform_im(Orientation, dict(axcodes="RPI", image_only=True), data)
+ create_transform_im(Orientation, dict(axcodes="RPI"), data)
create_transform_im(Orientationd, dict(keys=keys, axcodes="RPI"), data)
create_transform_im(
Rand3DElastic, dict(prob=1.0, sigma_range=(1, 2), magnitude_range=(0.5, 0.5), shear_range=(1, 1, 1)), data
@@ -538,11 +547,7 @@ def create_transform_im(
create_transform_im(KSpaceSpikeNoise, dict(loc=(100, 100, 100), k_intensity=13), data)
create_transform_im(KSpaceSpikeNoised, dict(keys=CommonKeys.IMAGE, loc=(100, 100, 100), k_intensity=13), data)
create_transform_im(RandKSpaceSpikeNoise, dict(prob=1, intensity_range=(10, 13)), data)
- create_transform_im(
- RandKSpaceSpikeNoised,
- dict(keys=CommonKeys.IMAGE, global_prob=1, prob=1, common_sampling=True, intensity_range=(13, 15)),
- data,
- )
+ create_transform_im(RandKSpaceSpikeNoised, dict(keys=CommonKeys.IMAGE, prob=1, intensity_range=(13, 15)), data)
create_transform_im(RandRicianNoise, dict(prob=1.0, mean=1, std=0.5), data)
create_transform_im(RandRicianNoised, dict(keys=CommonKeys.IMAGE, prob=1.0, mean=1, std=0.5), data)
create_transform_im(SavitzkyGolaySmooth, dict(window_length=5, order=1), data)
@@ -594,6 +599,8 @@ def create_transform_im(
create_transform_im(ForegroundMaskd, dict(keys=CommonKeys.IMAGE, invert=True), data)
create_transform_im(GaussianSmooth, dict(sigma=2), data)
create_transform_im(GaussianSmoothd, dict(keys=CommonKeys.IMAGE, sigma=2), data)
+ create_transform_im(MedianSmooth, dict(radius=3), data)
+ create_transform_im(MedianSmoothD, dict(keys=keys, radius=1), data)
create_transform_im(RandGaussianSmooth, dict(prob=1.0, sigma_x=(1, 2)), data)
create_transform_im(RandGaussianSmoothd, dict(keys=CommonKeys.IMAGE, prob=1.0, sigma_x=(1, 2)), data)
create_transform_im(GaussianSharpen, dict(), GaussianSmoothd(CommonKeys.IMAGE, 2)(data))
@@ -677,7 +684,7 @@ def create_transform_im(
)
create_transform_im(LabelToContour, dict(), data, is_post=True)
create_transform_im(LabelToContourd, dict(keys=CommonKeys.LABEL), data, is_post=True)
- create_transform_im(Spacing, dict(pixdim=(5, 5, 5), image_only=True), data)
+ create_transform_im(Spacing, dict(pixdim=(5, 5, 5)), data)
create_transform_im(Spacingd, dict(keys=keys, pixdim=(5, 5, 5), mode=["bilinear", "nearest"]), data)
create_transform_im(RandAxisFlip, dict(prob=1), data)
create_transform_im(RandAxisFlipd, dict(keys=keys, prob=1), data)
@@ -689,6 +696,10 @@ def create_transform_im(
create_transform_im(
KeepLargestConnectedComponentd, dict(keys=CommonKeys.LABEL, applied_labels=1), data_binary, is_post=True, ndim=2
)
+ create_transform_im(RemoveSmallObjects, dict(min_size=100), data_binary, is_post=True, ndim=2)
+ create_transform_im(
+ RemoveSmallObjectsd, dict(keys=CommonKeys.LABEL, min_size=100), data_binary, is_post=True, ndim=2
+ )
create_transform_im(
GridDistortion, dict(num_cells=3, distort_steps=[(1.5,) * 4] * 3, mode="nearest", padding_mode="zeros"), data
)
diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py
index af9d51efb62..aef4a32fe3d 100644
--- a/monai/transforms/utils_pytorch_numpy_unification.py
+++ b/monai/transforms/utils_pytorch_numpy_unification.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Sequence, Union
+from typing import Optional, Sequence, Tuple, Union
import numpy as np
import torch
@@ -42,6 +42,11 @@
"stack",
"mode",
"unique",
+ "max",
+ "min",
+ "median",
+ "mean",
+ "std",
]
@@ -87,8 +92,8 @@ def percentile(
https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.
Args:
- x: input data
- q: percentile to compute (should in range 0 <= q <= 100)
+ x: input data.
+ q: percentile to compute (should in range 0 <= q <= 100).
dim: the dim along which the percentiles are computed. default is to compute the percentile
along a flattened version of the array.
keepdim: whether the output data has dim retained or not.
@@ -98,16 +103,16 @@ def percentile(
Returns:
Resulting value (scalar)
"""
- if np.isscalar(q):
- if not 0 <= q <= 100: # type: ignore
- raise ValueError
- elif any(q < 0) or any(q > 100):
- raise ValueError
+ q_np = convert_data_type(q, output_type=np.ndarray, wrap_sequence=True)[0]
+ if ((q_np < 0) | (q_np > 100)).any():
+ raise ValueError(f"q values must be in [0, 100], got values: {q}.")
result: Union[NdarrayOrTensor, float, int]
- if isinstance(x, np.ndarray):
- result = np.percentile(x, q, axis=dim, keepdims=keepdim, **kwargs)
+ if isinstance(x, np.ndarray) or (isinstance(x, torch.Tensor) and torch.numel(x) > 1_000_000): # pytorch#64947
+ _x = convert_data_type(x, output_type=np.ndarray)[0]
+ result = np.percentile(_x, q_np, axis=dim, keepdims=keepdim, **kwargs)
+ result = convert_to_dst_type(result, x)[0]
else:
- q = convert_to_dst_type(q / 100.0, x)[0]
+ q = convert_to_dst_type(q_np / 100.0, x)[0]
result = torch.quantile(x, q, dim=dim, keepdim=keepdim)
return result
@@ -136,7 +141,7 @@ def nonzero(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.nonzero` with equivalent implementation for torch.
Args:
- x: array/tensor
+ x: array/tensor.
Returns:
Index unravelled for given shape
@@ -170,8 +175,8 @@ def unravel_index(idx, shape) -> NdarrayOrTensor:
"""`np.unravel_index` with equivalent implementation for torch.
Args:
- idx: index to unravel
- shape: shape of array/tensor
+ idx: index to unravel.
+ shape: shape of array/tensor.
Returns:
Index unravelled for given shape
@@ -189,8 +194,8 @@ def unravel_indices(idx, shape) -> NdarrayOrTensor:
"""Computing unravel coordinates from indices.
Args:
- idx: a sequence of indices to unravel
- shape: shape of array/tensor
+ idx: a sequence of indices to unravel.
+ shape: shape of array/tensor.
Returns:
Stacked indices unravelled for given shape
@@ -203,7 +208,7 @@ def ravel(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.ravel` with equivalent implementation for torch.
Args:
- x: array/tensor to ravel
+ x: array/tensor to ravel.
Returns:
Return a contiguous flattened array/tensor.
@@ -221,8 +226,8 @@ def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]) -> NdarrayOrT
For pytorch, convert to boolean for compatibility with older versions.
Args:
- x: input array/tensor
- axis: axis to perform `any` over
+ x: input array/tensor.
+ axis: axis to perform `any` over.
Returns:
Return a contiguous flattened array/tensor.
@@ -245,8 +250,8 @@ def maximum(a: NdarrayOrTensor, b: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.maximum` with equivalent implementation for torch.
Args:
- a: first array/tensor
- b: second array/tensor
+ a: first array/tensor.
+ b: second array/tensor.
Returns:
Element-wise maximum between two arrays/tensors.
@@ -285,7 +290,7 @@ def cumsum(a: NdarrayOrTensor, axis=None, **kwargs) -> NdarrayOrTensor:
def isfinite(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.isfinite` with equivalent implementation for torch."""
if not isinstance(x, torch.Tensor):
- return np.isfinite(x)
+ return np.isfinite(x) # type: ignore
return torch.isfinite(x)
@@ -329,11 +334,11 @@ def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.isnan` with equivalent implementation for torch.
Args:
- x: array/tensor
+ x: array/tensor.
"""
if isinstance(x, np.ndarray):
- return np.isnan(x)
+ return np.isnan(x) # type: ignore
return torch.isnan(x)
@@ -341,7 +346,7 @@ def ascontiguousarray(x: NdarrayTensor, **kwargs) -> NdarrayOrTensor:
"""`np.ascontiguousarray` with equivalent implementation for torch (`contiguous`).
Args:
- x: array/tensor
+ x: array/tensor.
kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.
@@ -359,8 +364,8 @@ def stack(x: Sequence[NdarrayTensor], dim: int) -> NdarrayTensor:
"""`np.stack` with equivalent implementation for torch.
Args:
- x: array/tensor
- dim: dimension along which to perform the stack (referred to as `axis` by numpy)
+ x: array/tensor.
+ dim: dimension along which to perform the stack (referred to as `axis` by numpy).
"""
if isinstance(x[0], np.ndarray):
return np.stack(x, dim) # type: ignore
@@ -371,8 +376,8 @@ def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor
"""`torch.mode` with equivalent implementation for numpy.
Args:
- x: array/tensor
- dim: dimension along which to perform `mode` (referred to as `axis` by numpy)
+ x: array/tensor.
+ dim: dimension along which to perform `mode` (referred to as `axis` by numpy).
to_long: convert input to long before performing mode.
"""
dtype = torch.int64 if to_long else None
@@ -382,21 +387,154 @@ def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor
return o
-def unique(x: NdarrayTensor) -> NdarrayTensor:
+def unique(x: NdarrayTensor, **kwargs) -> NdarrayTensor:
"""`torch.unique` with equivalent implementation for numpy.
Args:
- x: array/tensor
+ x: array/tensor.
"""
- return torch.unique(x) if isinstance(x, torch.Tensor) else np.unique(x) # type: ignore
+ return np.unique(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.unique(x, **kwargs) # type: ignore
def linalg_inv(x: NdarrayTensor) -> NdarrayTensor:
"""`torch.linalg.inv` with equivalent implementation for numpy.
Args:
- x: array/tensor
+ x: array/tensor.
"""
if isinstance(x, torch.Tensor) and hasattr(torch, "inverse"): # pytorch 1.7.0
return torch.inverse(x) # type: ignore
return torch.linalg.inv(x) if isinstance(x, torch.Tensor) else np.linalg.inv(x) # type: ignore
+
+
+def max(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor:
+ """`torch.max` with equivalent implementation for numpy
+
+ Args:
+ x: array/tensor.
+
+ Returns:
+ the maximum of x.
+
+ """
+
+ ret: NdarrayTensor
+ if dim is None:
+ ret = np.max(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.max(x, **kwargs) # type: ignore
+ else:
+ if isinstance(x, (np.ndarray, list)):
+ ret = np.max(x, axis=dim, **kwargs)
+ else:
+ ret = torch.max(x, int(dim), **kwargs) # type: ignore
+
+ return ret
+
+
+def mean(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor:
+ """`torch.mean` with equivalent implementation for numpy
+
+ Args:
+ x: array/tensor.
+
+ Returns:
+ the mean of x
+ """
+
+ ret: NdarrayTensor
+ if dim is None:
+ ret = np.mean(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.mean(x, **kwargs) # type: ignore
+ else:
+ if isinstance(x, (np.ndarray, list)):
+ ret = np.mean(x, axis=dim, **kwargs)
+ else:
+ ret = torch.mean(x, int(dim), **kwargs) # type: ignore
+
+ return ret
+
+
+def median(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor:
+ """`torch.median` with equivalent implementation for numpy
+
+ Args:
+ x: array/tensor.
+
+ Returns
+ the median of x.
+ """
+
+ ret: NdarrayTensor
+ if dim is None:
+ ret = np.median(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.median(x, **kwargs) # type: ignore
+ else:
+ if isinstance(x, (np.ndarray, list)):
+ ret = np.median(x, axis=dim, **kwargs)
+ else:
+ ret = torch.median(x, int(dim), **kwargs) # type: ignore
+
+ return ret
+
+
+def min(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor:
+ """`torch.min` with equivalent implementation for numpy
+
+ Args:
+ x: array/tensor.
+
+ Returns:
+ the minimum of x.
+ """
+
+ ret: NdarrayTensor
+ if dim is None:
+ ret = np.min(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.min(x, **kwargs) # type: ignore
+ else:
+ if isinstance(x, (np.ndarray, list)):
+ ret = np.min(x, axis=dim, **kwargs)
+ else:
+ ret = torch.min(x, int(dim), **kwargs) # type: ignore
+
+ return ret
+
+
+def std(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, unbiased: bool = False) -> NdarrayTensor:
+ """`torch.std` with equivalent implementation for numpy
+
+ Args:
+ x: array/tensor.
+
+ Returns:
+ the standard deviation of x.
+ """
+
+ ret: NdarrayTensor
+ if dim is None:
+ ret = np.std(x) if isinstance(x, (np.ndarray, list)) else torch.std(x, unbiased) # type: ignore
+ else:
+ if isinstance(x, (np.ndarray, list)):
+ ret = np.std(x, axis=dim)
+ else:
+ ret = torch.std(x, int(dim), unbiased) # type: ignore
+
+ return ret
+
+
+def sum(x: NdarrayTensor, dim: Optional[Union[int, Tuple]] = None, **kwargs) -> NdarrayTensor:
+ """`torch.sum` with equivalent implementation for numpy
+
+ Args:
+ x: array/tensor.
+
+ Returns:
+ the sum of x.
+ """
+
+ ret: NdarrayTensor
+ if dim is None:
+ ret = np.sum(x, **kwargs) if isinstance(x, (np.ndarray, list)) else torch.sum(x, **kwargs) # type: ignore
+ else:
+ if isinstance(x, (np.ndarray, list)):
+ ret = np.sum(x, axis=dim, **kwargs)
+ else:
+ ret = torch.sum(x, int(dim), **kwargs) # type: ignore
+
+ return ret
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index 4cdeef5adc0..21d3621090a 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -19,26 +19,34 @@
BlendMode,
BoxModeName,
ChannelMatching,
+ ColorOrder,
CommonKeys,
DiceCEReduction,
+ EngineStatsKeys,
FastMRIKeys,
ForwardMode,
+ GanKeys,
GridPatchSort,
GridSampleMode,
GridSamplePadMode,
+ HoVerNetBranch,
+ HoVerNetMode,
InterpolateMode,
InverseKeys,
JITMetadataKeys,
+ LazyAttr,
LossReduction,
MetaKeys,
Method,
MetricReduction,
+ NdimageMode,
NumpyPadMode,
PostFix,
ProbMapKeys,
PytorchPadMode,
SkipMode,
SpaceKeys,
+ SplineMode,
StrEnum,
TraceKeys,
TransformBackends,
@@ -50,6 +58,7 @@
from .misc import (
MAX_SEED,
ImageMetaKey,
+ MONAIEnvVars,
check_parent_dir,
copy_to_device,
ensure_tuple,
@@ -69,11 +78,14 @@
save_obj,
set_determinism,
star_zip_with,
+ str2bool,
+ str2list,
zip_with,
)
from .module import (
InvalidPyTorchVersionError,
OptionalImportError,
+ allow_missing_reference,
damerau_levenshtein_distance,
exact_version,
export,
@@ -87,6 +99,8 @@
optional_import,
pytorch_after,
require_pkg,
+ run_debug,
+ run_eval,
version_leq,
)
from .nvtx import Range
diff --git a/monai/utils/aliases.py b/monai/utils/aliases.py
index 0ae79e26ff8..1a63c3aba87 100644
--- a/monai/utils/aliases.py
+++ b/monai/utils/aliases.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
This module is written for configurable workflow, not currently in use.
"""
diff --git a/monai/utils/deprecate_utils.py b/monai/utils/deprecate_utils.py
index 209772a3ec3..68f2d6e46d7 100644
--- a/monai/utils/deprecate_utils.py
+++ b/monai/utils/deprecate_utils.py
@@ -57,7 +57,7 @@ def deprecated(
Args:
since: version at which the definition was marked deprecated but not removed.
- removed: version at which the definition was removed and no longer usable.
+ removed: version at which the definition was/will be removed and no longer usable.
msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.
version_val: (used for testing) version to compare since and removed against, default is MONAI version.
warning_category: a warning category class, defaults to `FutureWarning`.
@@ -66,9 +66,6 @@ def deprecated(
Decorated definition which warns or raises exception when used
"""
- # if version_val.startswith("0+"):
- # # version unknown, set version_val to a large value (assuming the latest version)
- # version_val = f"{sys.maxsize}"
if since is not None and removed is not None and not version_leq(since, removed):
raise ValueError(f"since must be less or equal to removed, got since={since}, removed={removed}.")
is_not_yet_deprecated = since is not None and version_val != since and version_leq(version_val, since)
@@ -147,7 +144,7 @@ def deprecated_arg(
Args:
name: name of position or keyword argument to mark as deprecated.
since: version at which the argument was marked deprecated but not removed.
- removed: version at which the argument was removed and no longer usable.
+ removed: version at which the argument was/will be removed and no longer usable.
msg_suffix: message appended to warning/exception detailing reasons for deprecation and what to use instead.
version_val: (used for testing) version to compare since and removed against, default is MONAI version.
new_name: name of position or keyword argument to replace the deprecated argument.
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index a6d9a23309a..4fd9bea5577 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -19,10 +19,12 @@
"StrEnum",
"NumpyPadMode",
"GridSampleMode",
+ "SplineMode",
"InterpolateMode",
"UpsampleMode",
"BlendMode",
"PytorchPadMode",
+ "NdimageMode",
"GridSamplePadMode",
"Average",
"MetricReduction",
@@ -35,6 +37,7 @@
"TraceKeys",
"InverseKeys",
"CommonKeys",
+ "GanKeys",
"PostFix",
"ForwardMode",
"TransformBackends",
@@ -43,6 +46,15 @@
"FastMRIKeys",
"SpaceKeys",
"MetaKeys",
+ "ColorOrder",
+ "EngineStatsKeys",
+ "DataStatsKeys",
+ "ImageStatsKeys",
+ "LabelStatsKeys",
+ "AlgoEnsembleKeys",
+ "HoVerNetMode",
+ "HoVerNetBranch",
+ "LazyAttr",
]
@@ -89,6 +101,22 @@ class NumpyPadMode(StrEnum):
EMPTY = "empty"
+class NdimageMode(StrEnum):
+ """
+ The available options determine how the input array is extended beyond its boundaries when interpolating.
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ """
+
+ REFLECT = "reflect"
+ GRID_MIRROR = "grid-mirror"
+ CONSTANT = "constant"
+ GRID_CONSTANT = "grid-constant"
+ NEAREST = "nearest"
+ MIRROR = "mirror"
+ GRID_WRAP = "grid-wrap"
+ WRAP = "wrap"
+
+
class GridSampleMode(StrEnum):
"""
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
@@ -107,6 +135,21 @@ class GridSampleMode(StrEnum):
BICUBIC = "bicubic"
+class SplineMode(StrEnum):
+ """
+ Order of spline interpolation.
+
+ See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
+ """
+
+ ZERO = 0
+ ONE = 1
+ TWO = 2
+ THREE = 3
+ FOUR = 4
+ FIVE = 5
+
+
class InterpolateMode(StrEnum):
"""
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
@@ -127,6 +170,7 @@ class UpsampleMode(StrEnum):
"""
DECONV = "deconv"
+ DECONVGROUP = "deconvgroup"
NONTRAINABLE = "nontrainable" # e.g. using torch.nn.Upsample
PIXELSHUFFLE = "pixelshuffle"
@@ -306,6 +350,19 @@ class CommonKeys(StrEnum):
METADATA = "metadata"
+class GanKeys(StrEnum):
+ """
+ A set of common keys for generative adversarial networks.
+
+ """
+
+ REALS = "reals"
+ FAKES = "fakes"
+ LATENTS = "latents"
+ GLOSS = "g_loss"
+ DLOSS = "d_loss"
+
+
class PostFix(StrEnum):
"""Post-fixes."""
@@ -328,11 +385,16 @@ def transforms(key: Optional[str] = None):
class TransformBackends(StrEnum):
"""
- Transform backends.
+ Transform backends. Most of `monai.transforms` components first converts the input data into ``torch.Tensor`` or
+ ``monai.data.MetaTensor``. Internally, some transforms are made by converting the data into ``numpy.array`` or
+ ``cupy.array`` and use the underlying transform backend API to achieve the actual output array and
+ converting back to ``Tensor``/``MetaTensor``. Transforms with more than one backend indicate the that they may
+ convert the input data types to accomodate the underlying API.
"""
TORCH = "torch"
NUMPY = "numpy"
+ CUPY = "cupy"
class JITMetadataKeys(StrEnum):
@@ -411,10 +473,10 @@ class WSIPatchKeys(StrEnum):
The keys to be used for metadata of patches extracted from whole slide images
"""
- LOCATION = "patch_location"
- LEVEL = "patch_level"
- SIZE = "patch_size"
- COUNT = "num_patches"
+ LOCATION = "location"
+ LEVEL = "level"
+ SIZE = "size"
+ COUNT = "count"
PATH = "path"
@@ -453,3 +515,118 @@ class MetaKeys(StrEnum):
SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension
SPACE = "space" # possible values of space type are defined in `SpaceKeys`
ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or "no_channel"
+
+
+class ColorOrder(StrEnum):
+ """
+ Enums for color order. Expand as necessary.
+ """
+
+ RGB = "RGB"
+ BGR = "BGR"
+
+
+class EngineStatsKeys(StrEnum):
+ """
+ Default keys for the statistics of trainer and evaluator engines.
+
+ """
+
+ RANK = "rank"
+ CURRENT_ITERATION = "current_iteration"
+ CURRENT_EPOCH = "current_epoch"
+ TOTAL_EPOCHS = "total_epochs"
+ TOTAL_ITERATIONS = "total_iterations"
+ BEST_VALIDATION_EPOCH = "best_validation_epoch"
+ BEST_VALIDATION_METRIC = "best_validation_metric"
+
+
+class DataStatsKeys(StrEnum):
+ """
+ Defaults keys for dataset statistical analysis modules
+
+ """
+
+ SUMMARY = "stats_summary"
+ BY_CASE = "stats_by_cases"
+ BY_CASE_IMAGE_PATH = "image_filepath"
+ BY_CASE_LABEL_PATH = "label_filepath"
+ IMAGE_STATS = "image_stats"
+ FG_IMAGE_STATS = "image_foreground_stats"
+ LABEL_STATS = "label_stats"
+ IMAGE_HISTOGRAM = "image_histogram"
+
+
+class ImageStatsKeys(StrEnum):
+ """
+ Defaults keys for dataset statistical analysis image modules
+
+ """
+
+ SHAPE = "shape"
+ CHANNELS = "channels"
+ CROPPED_SHAPE = "cropped_shape"
+ SPACING = "spacing"
+ INTENSITY = "intensity"
+ HISTOGRAM = "histogram"
+
+
+class LabelStatsKeys(StrEnum):
+ """
+ Defaults keys for dataset statistical analysis label modules
+
+ """
+
+ LABEL_UID = "labels"
+ PIXEL_PCT = "foreground_percentage"
+ IMAGE_INTST = "image_intensity"
+ LABEL = "label"
+ LABEL_SHAPE = "shape"
+ LABEL_NCOMP = "ncomponents"
+
+
+class AlgoEnsembleKeys(StrEnum):
+ """
+ Default keys for Mixed Ensemble
+ """
+
+ ID = "identifier"
+ ALGO = "infer_algo"
+ SCORE = "best_metric"
+
+
+class HoVerNetMode(StrEnum):
+ """
+ Modes for HoVerNet model:
+ `FAST`: a faster implementation (than original)
+ `ORIGINAL`: the original implementation
+ """
+
+ FAST = "FAST"
+ ORIGINAL = "ORIGINAL"
+
+
+class HoVerNetBranch(StrEnum):
+ """
+ Three branches of HoVerNet model, which results in three outputs:
+ `HV` is horizontal and vertical gradient map of each nucleus (regression),
+ `NP` is the pixel prediction of all nuclei (segmentation), and
+ `NC` is the type of each nucleus (classification).
+ """
+
+ HV = "horizontal_vertical"
+ NP = "nucleus_prediction"
+ NC = "type_prediction"
+
+
+class LazyAttr(StrEnum):
+ """
+ MetaTensor with pending operations requires some key attributes tracked especially when the primary array
+ is not up-to-date due to lazy evaluation.
+ This class specifies the set of key attributes to be tracked for each MetaTensor.
+ """
+
+ SHAPE = "lazy_shape" # spatial shape
+ AFFINE = "lazy_affine"
+ PADDING_MODE = "lazy_padding_mode"
+ INTERP_MODE = "lazy_interpolation_mode"
diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py
index 366d11ebd81..f9eb00aa02f 100644
--- a/monai/utils/jupyter_utils.py
+++ b/monai/utils/jupyter_utils.py
@@ -8,12 +8,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
This set of utility function is meant to make using Jupyter notebooks easier with MONAI. Plotting functions using
Matplotlib produce common plots for metrics and images.
"""
+import copy
from enum import Enum
from threading import RLock, Thread
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -340,7 +340,7 @@ def status_dict(self) -> Dict[str, str]:
def status(self) -> str:
"""Returns a status string for the current state of the engine."""
- stats = self.status_dict
+ stats = copy.deepcopy(self.status_dict)
msgs = [stats.pop(StatusMembers.STATUS.value), "Iters: " + str(stats.pop(StatusMembers.ITERS.value, 0))]
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index fc38dc5056d..1071f37840a 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -21,7 +21,7 @@
from collections.abc import Iterable
from distutils.util import strtobool
from pathlib import Path
-from typing import Any, Callable, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast
import numpy as np
import torch
@@ -46,12 +46,16 @@
"list_to_dict",
"MAX_SEED",
"copy_to_device",
+ "str2bool",
+ "str2list",
+ "MONAIEnvVars",
"ImageMetaKey",
"is_module_ver_at_least",
"has_option",
"sample_slices",
"check_parent_dir",
"save_obj",
+ "label_union",
]
_seed = None
@@ -360,6 +364,92 @@ def copy_to_device(
return obj
+def str2bool(value: Union[str, bool], default: bool = False, raise_exc: bool = True) -> bool:
+ """
+ Convert a string to a boolean. Case insensitive.
+ True: yes, true, t, y, 1. False: no, false, f, n, 0.
+
+ Args:
+ value: string to be converted to a boolean. If value is a bool already, simply return it.
+ raise_exc: if value not in tuples of expected true or false inputs,
+ should we raise an exception? If not, return `default`.
+ Raises
+ ValueError: value not in tuples of expected true or false inputs and
+ `raise_exc` is `True`.
+ Useful with argparse, for example:
+ parser.add_argument("--convert", default=False, type=str2bool)
+ python mycode.py --convert=True
+ """
+
+ if isinstance(value, bool):
+ return value
+
+ true_set = ("yes", "true", "t", "y", "1")
+ false_set = ("no", "false", "f", "n", "0")
+
+ if isinstance(value, str):
+ value = value.lower()
+ if value in true_set:
+ return True
+ if value in false_set:
+ return False
+
+ if raise_exc:
+ raise ValueError(f"Got \"{value}\", expected a value from: {', '.join(true_set + false_set)}")
+ return default
+
+
+def str2list(value: Optional[Union[str, list]], raise_exc: bool = True) -> Optional[list]:
+ """
+ Convert a string to a list. Useful with argparse commandline arguments:
+ parser.add_argument("--blocks", default=[1,2,3], type=str2list)
+ python mycode.py --blocks=1,2,2,4
+
+ Args:
+ value: string (comma separated) to be converted to a list
+ raise_exc: if not possible to convert to a list, raise an exception
+ Raises
+ ValueError: value not a string or list or not possible to convert
+ """
+
+ if value is None:
+ return None
+ elif isinstance(value, list):
+ return value
+ elif isinstance(value, str):
+ v = value.split(",")
+ for i in range(len(v)):
+ try:
+ a = literal_eval(v[i].strip()) # attempt to convert
+ v[i] = a
+ except Exception:
+ pass
+ return v
+ elif raise_exc:
+ raise ValueError(f'Unable to convert "{value}", expected a comma-separated str, e.g. 1,2,3')
+
+ return None
+
+
+class MONAIEnvVars:
+ """
+ Environment variables used by MONAI.
+ """
+
+ @staticmethod
+ def data_dir() -> Optional[str]:
+ return os.environ.get("MONAI_DATA_DIRECTORY")
+
+ @staticmethod
+ def debug() -> bool:
+ val = os.environ.get("MONAI_DEBUG", False)
+ return val if isinstance(val, bool) else str2bool(val)
+
+ @staticmethod
+ def doc_images() -> Optional[str]:
+ return os.environ.get("MONAI_DOC_IMAGES")
+
+
class ImageMetaKey:
"""
Common key names in the metadata header of images
@@ -471,3 +561,26 @@ def save_obj(
shutil.move(str(temp_path), path)
except PermissionError: # project-monai/monai issue #3613
pass
+
+
+def label_union(x: List) -> List:
+ """
+ Compute the union of class IDs in label and generate a list to include all class IDs
+ Args:
+ x: a list of numbers (for example, class_IDs)
+
+ Returns
+ a list showing the union (the union the class IDs)
+ """
+ return list(set.union(set(np.array(x).tolist())))
+
+
+def prob2class(x, sigmoid: bool = False, threshold: float = 0.5, **kwargs):
+ """
+ Compute the lab from the probability of predicted feature maps
+
+ Args:
+ sigmoid: If the sigmoid function should be used.
+ threshold: threshold value to activate the sigmoid function.
+ """
+ return torch.argmax(x, **kwargs) if not sigmoid else (x > threshold).int()
diff --git a/monai/utils/module.py b/monai/utils/module.py
index 747c985af71..435b07fcac9 100644
--- a/monai/utils/module.py
+++ b/monai/utils/module.py
@@ -25,6 +25,14 @@
import torch
+# bundle config system flags
+# set MONAI_EVAL_EXPR=1 to use 'eval', default value: run_eval=True
+run_eval = os.environ.get("MONAI_EVAL_EXPR", "1") != "0"
+# set MONAI_DEBUG_CONFIG=1 to run in debug mode, default value: run_debug=False
+run_debug = os.environ.get("MONAI_DEBUG_CONFIG", "0") != "0"
+# set MONAI_ALLOW_MISSING_REFERENCE=1 to allow missing references, default value: allow_missing_reference=False
+allow_missing_reference = os.environ.get("MONAI_ALLOW_MISSING_REFERENCE", "0") != "0"
+
OPTIONAL_IMPORT_MSG_FMT = "{}"
__all__ = [
@@ -193,6 +201,12 @@ def load_submodules(basemod, load_all: bool = True, exclude_pattern: str = "(.*[
submodules.append(mod)
except OptionalImportError:
pass # could not import the optional deps., they are ignored
+ except ImportError as e:
+ msg = (
+ "\nMultiple versions of MONAI may have been installed?\n"
+ "Please see the installation guide: https://docs.monai.io/en/stable/installation.html\n"
+ ) # issue project-monai/monai#5193
+ raise type(e)(f"{e}\n{msg}").with_traceback(e.__traceback__) from e # raise with modified message
return submodules, err_mod
@@ -214,6 +228,14 @@ def instantiate(path: str, **kwargs):
if component is None:
raise ModuleNotFoundError(f"Cannot locate class or function path: '{path}'.")
try:
+ if kwargs.pop("_debug_", False) or run_debug:
+ warnings.warn(
+ f"\n\npdb: instantiating component={component}\n"
+ f"See also Debugger commands documentation: https://docs.python.org/3/library/pdb.html\n"
+ )
+ import pdb
+
+ pdb.set_trace()
if isclass(component):
return component(**kwargs)
# support regular function, static method and class method
@@ -237,7 +259,7 @@ def get_full_type_name(typeobj):
return module + "." + typeobj.__name__
-def min_version(the_module, min_version_str: str = "") -> bool:
+def min_version(the_module, min_version_str: str = "", *_args) -> bool:
"""
Convert version strings into tuples of int and compare them.
@@ -252,7 +274,7 @@ def min_version(the_module, min_version_str: str = "") -> bool:
return mod_version >= required
-def exact_version(the_module, version_str: str = "") -> bool:
+def exact_version(the_module, version_str: str = "", *_args) -> bool:
"""
Returns True if the module's __version__ matches version_str
"""
@@ -287,6 +309,7 @@ def optional_import(
descriptor: str = OPTIONAL_IMPORT_MSG_FMT,
version_args=None,
allow_namespace_pkg: bool = False,
+ as_type: str = "default",
) -> Tuple[Any, bool]:
"""
Imports an optional module specified by `module` string.
@@ -301,6 +324,10 @@ def optional_import(
descriptor: a format string for the final error message when using a not imported module.
version_args: additional parameters to the version checker.
allow_namespace_pkg: whether importing a namespace package is allowed. Defaults to False.
+ as_type: there are cases where the optionally imported object is used as
+ a base class, or a decorator, the exceptions should raise accordingly. The current supported values
+ are "default" (call once to raise), "decorator" (call the constructor and the second call to raise),
+ and anything else will return a lazy class that can be used as a base class (call the constructor to raise).
Returns:
The imported module and a boolean flag indicating whether the import is successful.
@@ -387,7 +414,22 @@ def __call__(self, *_args, **_kwargs):
"""
raise self._exception
- return _LazyRaise(), False
+ def __getitem__(self, item):
+ raise self._exception
+
+ def __iter__(self):
+ raise self._exception
+
+ if as_type == "default":
+ return _LazyRaise(), False
+
+ class _LazyCls(_LazyRaise):
+ def __init__(self, *_args, **kwargs):
+ super().__init__()
+ if not as_type.startswith("decorator"):
+ raise self._exception
+
+ return _LazyCls, False
def require_pkg(
diff --git a/monai/utils/profiling.py b/monai/utils/profiling.py
index 291e58d57fa..3cf558a0389 100644
--- a/monai/utils/profiling.py
+++ b/monai/utils/profiling.py
@@ -162,7 +162,7 @@ class WorkflowProfiler:
The tracing functionality uses a selector to choose which calls to trace, since tracing all calls induces
infinite loops and would be terribly slow even if not. This selector is a callable accepting a `call` trace
- frame and returns True if the call should be traced. The dedault is `select_transform_call` which will return
+ frame and returns True if the call should be traced. The default is `select_transform_call` which will return
True for `Transform.__call__` calls only.
Example showing use of all profiling functions:
@@ -369,7 +369,7 @@ def get_times_summary(self, times_in_s=True):
return result
def get_times_summary_pd(self, times_in_s=True):
- """Returns the same informatoin as `get_times_summary` but in a Pandas DataFrame."""
+ """Returns the same information as `get_times_summary` but in a Pandas DataFrame."""
import pandas as pd
summ = self.get_times_summary(times_in_s)
diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py
index 9c9fb1a4b23..b21edf6496b 100644
--- a/monai/utils/type_conversion.py
+++ b/monai/utils/type_conversion.py
@@ -36,6 +36,9 @@
"convert_to_dst_type",
]
+# conversion map for types unsupported by torch.as_tensor
+UNSUPPORTED_TYPES = {np.dtype("uint16"): np.int32, np.dtype("uint32"): np.int64, np.dtype("uint64"): np.int64}
+
def get_numpy_dtype_from_string(dtype: str) -> np.dtype:
"""Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`)."""
@@ -98,7 +101,7 @@ def get_dtype(data: Any):
def convert_to_tensor(
data,
- dtype: Optional[torch.dtype] = None,
+ dtype: Union[DtypeLike, torch.dtype] = None,
device: Union[None, str, torch.device] = None,
wrap_sequence: bool = False,
track_meta: bool = False,
@@ -123,6 +126,10 @@ def convert_to_tensor(
def _convert_tensor(tensor, **kwargs):
if not isinstance(tensor, torch.Tensor):
+ # certain numpy types are not supported as being directly convertible to Pytorch tensors
+ if isinstance(tensor, np.ndarray) and tensor.dtype in UNSUPPORTED_TYPES:
+ tensor = tensor.astype(UNSUPPORTED_TYPES[tensor.dtype])
+
# if input data is not Tensor, convert it to Tensor first
tensor = torch.as_tensor(tensor, **kwargs)
if track_meta and not isinstance(tensor, monai.data.MetaTensor):
@@ -175,6 +182,13 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False)
elif has_cp and isinstance(data, cp_ndarray):
data = cp.asnumpy(data).astype(dtype, copy=False)
elif isinstance(data, (np.ndarray, float, int, bool)):
+ # Convert into a contiguous array first if the current dtype's size is smaller than the target dtype's size.
+ # This help improve the performance because (convert to contiguous array) -> (convert dtype) is faster
+ # than (convert dtype) -> (convert to contiguous array) when src dtype (e.g., uint8) is smaller than
+ # target dtype(e.g., float32) and we are going to convert it to contiguous array anyway later in this
+ # method.
+ if isinstance(data, np.ndarray) and data.ndim > 0 and data.dtype.itemsize < np.dtype(dtype).itemsize:
+ data = np.ascontiguousarray(data)
data = np.asarray(data, dtype=dtype)
elif isinstance(data, list):
list_ret = [convert_to_numpy(i, dtype=dtype) for i in data]
@@ -229,7 +243,7 @@ def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool
def convert_data_type(
data: Any,
output_type: Optional[Type[NdarrayTensor]] = None,
- device: Optional[torch.device] = None,
+ device: Union[None, str, torch.device] = None,
dtype: Union[DtypeLike, torch.dtype] = None,
wrap_sequence: bool = False,
) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
@@ -293,7 +307,11 @@ def convert_data_type(
def convert_to_dst_type(
- src: Any, dst: NdarrayTensor, dtype: Union[DtypeLike, torch.dtype, None] = None, wrap_sequence: bool = False
+ src: Any,
+ dst: NdarrayTensor,
+ dtype: Union[DtypeLike, torch.dtype, None] = None,
+ wrap_sequence: bool = False,
+ device: Union[None, str, torch.device] = None,
) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
"""
Convert source data to the same data type and device as the destination data.
@@ -307,12 +325,13 @@ def convert_to_dst_type(
dtype: an optional argument if the target `dtype` is different from the original `dst`'s data type.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.
+ device: target device to put the converted Tensor data. If unspecified, `dst.device` will be used if possible.
See Also:
:func:`convert_data_type`
"""
- device = dst.device if isinstance(dst, torch.Tensor) else None
+ device = dst.device if device is None and isinstance(dst, torch.Tensor) else device
if dtype is None:
dtype = dst.dtype
diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py
index ba1f5d2589e..06999ebf1b4 100644
--- a/monai/visualize/class_activation_maps.py
+++ b/monai/visualize/class_activation_maps.py
@@ -125,10 +125,10 @@ def get_layer(self, layer_id: Union[str, Callable]):
def class_score(self, logits, class_idx):
return logits[:, class_idx].squeeze()
- def __call__(self, x, class_idx=None, retain_graph=False):
+ def __call__(self, x, class_idx=None, retain_graph=False, **kwargs):
train = self.model.training
self.model.eval()
- logits = self.model(x)
+ logits = self.model(x, **kwargs)
self.class_idx = logits.max(1)[-1] if class_idx is None else class_idx
acti, grad = None, None
if self.register_forward:
@@ -175,17 +175,18 @@ def __init__(
self.upsampler = upsampler
self.postprocessing = postprocessing
- def feature_map_size(self, input_size, device="cpu", layer_idx=-1):
+ def feature_map_size(self, input_size, device="cpu", layer_idx=-1, **kwargs):
"""
Computes the actual feature map size given `nn_module` and the target_layer name.
Args:
input_size: shape of the input tensor
device: the device used to initialise the input tensor
layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
+ kwargs: any extra arguments to be passed on to the module as part of its `__call__`.
Returns:
shape of the actual feature map.
"""
- return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx).shape
+ return self.compute_map(torch.zeros(*input_size, device=device), layer_idx=layer_idx, **kwargs).shape
def compute_map(self, x, class_idx=None, layer_idx=-1):
"""
@@ -286,8 +287,8 @@ def __init__(
)
self.fc_layers = fc_layers
- def compute_map(self, x, class_idx=None, layer_idx=-1):
- logits, acti, _ = self.nn_module(x)
+ def compute_map(self, x, class_idx=None, layer_idx=-1, **kwargs):
+ logits, acti, _ = self.nn_module(x, **kwargs)
acti = acti[layer_idx]
if class_idx is None:
class_idx = logits.max(1)[-1]
@@ -298,7 +299,7 @@ def compute_map(self, x, class_idx=None, layer_idx=-1):
output = torch.stack([output[i, b : b + 1] for i, b in enumerate(class_idx)], dim=0)
return output.reshape(b, 1, *spatial) # resume the spatial dims on the selected class
- def __call__(self, x, class_idx=None, layer_idx=-1):
+ def __call__(self, x, class_idx=None, layer_idx=-1, **kwargs):
"""
Compute the activation map with upsampling and postprocessing.
@@ -306,11 +307,12 @@ def __call__(self, x, class_idx=None, layer_idx=-1):
x: input tensor, shape must be compatible with `nn_module`.
class_idx: index of the class to be visualized. Default to argmax(logits)
layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
+ kwargs: any extra arguments to be passed on to the module as part of its `__call__`.
Returns:
activation maps
"""
- acti_map = self.compute_map(x, class_idx, layer_idx)
+ acti_map = self.compute_map(x, class_idx, layer_idx, **kwargs)
return self._upsample_and_post_process(acti_map, x)
@@ -356,15 +358,15 @@ class GradCAM(CAMBase):
"""
- def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1):
- _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph)
+ def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs):
+ _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs)
acti, grad = acti[layer_idx], grad[layer_idx]
b, c, *spatial = grad.shape
weights = grad.view(b, c, -1).mean(2).view(b, c, *[1] * len(spatial))
acti_map = (weights * acti).sum(1, keepdim=True)
return F.relu(acti_map)
- def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False):
+ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False, **kwargs):
"""
Compute the activation map with upsampling and postprocessing.
@@ -373,11 +375,12 @@ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False):
class_idx: index of the class to be visualized. Default to argmax(logits)
layer_idx: index of the target layer if there are multiple target layers. Defaults to -1.
retain_graph: whether to retain_graph for torch module backward call.
+ kwargs: any extra arguments to be passed on to the module as part of its `__call__`.
Returns:
activation maps
"""
- acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx)
+ acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx, **kwargs)
return self._upsample_and_post_process(acti_map, x)
@@ -395,8 +398,8 @@ class GradCAMpp(GradCAM):
"""
- def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1):
- _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph)
+ def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs):
+ _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs)
acti, grad = acti[layer_idx], grad[layer_idx]
b, c, *spatial = grad.shape
alpha_nr = grad.pow(2)
diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py
index 32b8110b6d6..7ab6ef260da 100644
--- a/monai/visualize/gradient_based.py
+++ b/monai/visualize/gradient_based.py
@@ -22,7 +22,6 @@
trange, has_trange = optional_import("tqdm", name="trange")
-
__all__ = ["VanillaGrad", "SmoothGrad", "GuidedBackpropGrad", "GuidedBackpropSmoothGrad"]
@@ -45,12 +44,29 @@ def backward(ctx, grad_output):
class _GradReLU(torch.nn.Module):
+ """
+ A customized ReLU with the backward pass imputed for guided backpropagation (https://arxiv.org/abs/1412.6806).
+ """
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
out: torch.Tensor = _AutoGradReLU.apply(x)
return out
class VanillaGrad:
+ """
+ Given an input image ``x``, calling this class will perform the forward pass, then set to zero
+ all activations except one (defined by ``index``) and propagate back to the image to achieve a gradient-based
+ saliency map.
+
+ If ``index`` is None, argmax of the output logits will be used.
+
+ See also:
+
+ - Simonyan et al. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps
+ (https://arxiv.org/abs/1312.6034)
+ """
+
def __init__(self, model: torch.nn.Module) -> None:
if not isinstance(model, ModelWithHooks): # Convert to model with hooks if necessary
self._model = ModelWithHooks(model, target_layer_names=(), register_backward=True)
@@ -68,22 +84,26 @@ def model(self, m):
else:
self._model = m # replace the ModelWithHooks
- def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True) -> torch.Tensor:
+ def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_graph=True, **kwargs) -> torch.Tensor:
if x.shape[0] != 1:
raise ValueError("expect batch size of 1")
x.requires_grad = True
- self._model(x, class_idx=index, retain_graph=retain_graph)
- grad: torch.Tensor = x.grad.detach()
+ self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs)
+ grad: torch.Tensor = x.grad.detach() # type: ignore
return grad
- def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
- return self.get_grad(x, index)
+ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
+ return self.get_grad(x, index, **kwargs)
class SmoothGrad(VanillaGrad):
"""
+ Compute averaged sensitivity map based on ``n_samples`` (Gaussian additive) of noisy versions
+ of the input image ``x``.
+
See also:
+
- Smilkov et al. SmoothGrad: removing noise by adding noise https://arxiv.org/abs/1706.03825
"""
@@ -105,7 +125,7 @@ def __init__(
else:
self.range = range
- def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
+ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
stdev = (self.stdev_spread * (x.max() - x.min())).item()
total_gradients = torch.zeros_like(x)
for _ in self.range(self.n_samples):
@@ -115,7 +135,7 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) ->
x_plus_noise = x_plus_noise.detach()
# get gradient and accumulate
- grad = self.get_grad(x_plus_noise, index)
+ grad = self.get_grad(x_plus_noise, index, **kwargs)
total_gradients += (grad * grad) if self.magnitude else grad
# average
@@ -126,12 +146,26 @@ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) ->
class GuidedBackpropGrad(VanillaGrad):
- def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
+ """
+ Based on Springenberg and Dosovitskiy et al. https://arxiv.org/abs/1412.6806,
+ compute gradient-based saliency maps by backpropagating positive graidents and inputs (see ``_AutoGradReLU``).
+
+ See also:
+
+ - Springenberg and Dosovitskiy et al. Striving for Simplicity: The All Convolutional Net
+ (https://arxiv.org/abs/1412.6806)
+ """
+
+ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
- return super().__call__(x, index)
+ return super().__call__(x, index, **kwargs)
class GuidedBackpropSmoothGrad(SmoothGrad):
- def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None) -> torch.Tensor:
+ """
+ Compute gradient-based saliency maps based on both ``GuidedBackpropGrad`` and ``SmoothGrad``.
+ """
+
+ def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
with replace_modules_temp(self.model, "relu", _GradReLU(), strict_match=False):
- return super().__call__(x, index)
+ return super().__call__(x, index, **kwargs)
diff --git a/monai/visualize/occlusion_sensitivity.py b/monai/visualize/occlusion_sensitivity.py
index d87b93396ae..03c69f8978e 100644
--- a/monai/visualize/occlusion_sensitivity.py
+++ b/monai/visualize/occlusion_sensitivity.py
@@ -10,89 +10,18 @@
# limitations under the License.
from collections.abc import Sequence
-from functools import partial
from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
+from monai.data.meta_tensor import MetaTensor
from monai.networks.utils import eval_mode
+from monai.transforms import Compose, GaussianSmooth, Lambda, ScaleIntensity, SpatialCrop
+from monai.utils import deprecated_arg, ensure_tuple_rep
from monai.visualize.visualizer import default_upsampler
-try:
- from tqdm import trange
-
- trange = partial(trange, desc="Computing occlusion sensitivity")
-except (ImportError, AttributeError):
- trange = range
-
-# For stride two (for example),
-# if input array is: |0|1|2|3|4|5|6|7|
-# downsampled output is: | 0 | 1 | 2 | 3 |
-# So the upsampling should do it by the corners of the image, not their centres
-default_upsampler = partial(default_upsampler, align_corners=True)
-
-
-def _check_input_image(image):
- """Check that the input image is as expected."""
- # Only accept batch size of 1
- if image.shape[0] > 1:
- raise RuntimeError("Expected batch size of 1.")
-
-
-def _check_input_bounding_box(b_box, im_shape):
- """Check that the bounding box (if supplied) is as expected."""
- # If no bounding box has been supplied, set min and max to None
- if b_box is None:
- b_box_min = b_box_max = None
-
- # Bounding box has been supplied
- else:
- # Should be twice as many elements in `b_box` as `im_shape`
- if len(b_box) != 2 * len(im_shape):
- raise ValueError("Bounding box should contain upper and lower for all dimensions (except batch number)")
-
- # If any min's or max's are -ve, set them to 0 and im_shape-1, respectively.
- b_box_min = np.array(b_box[::2])
- b_box_max = np.array(b_box[1::2])
- b_box_min[b_box_min < 0] = 0
- b_box_max[b_box_max < 0] = im_shape[b_box_max < 0] - 1
- # Check all max's are < im_shape
- if np.any(b_box_max >= im_shape):
- raise ValueError("Max bounding box should be < image size for all values")
- # Check all min's are <= max's
- if np.any(b_box_min > b_box_max):
- raise ValueError("Min bounding box should be <= max for all values")
-
- return b_box_min, b_box_max
-
-
-def _append_to_sensitivity_ims(model, batch_images, sensitivity_ims):
- """Infer given images. Append to previous evaluations. Store each class separately."""
- batch_images = torch.cat(batch_images, dim=0)
- scores = model(batch_images).detach()
- for i in range(scores.shape[1]):
- sensitivity_ims[i] = torch.cat((sensitivity_ims[i], scores[:, i]))
- return sensitivity_ims
-
-
-def _get_as_np_array(val, numel):
- # If not a sequence, then convert scalar to numpy array
- if not isinstance(val, Sequence):
- out = np.full(numel, val, dtype=np.int32)
- out[0] = 1 # mask_size and stride always 1 in channel dimension
- else:
- # Convert to numpy array and check dimensions match
- out = np.array(val, dtype=np.int32)
- # Add stride of 1 to the channel direction (since user input was only for spatial dimensions)
- out = np.insert(out, 0, 1)
- if out.size != numel:
- raise ValueError(
- "If supplying stride/mask_size as sequence, number of elements should match number of spatial dimensions."
- )
- return out
-
class OcclusionSensitivity:
"""
@@ -124,168 +53,231 @@ class OcclusionSensitivity:
# densenet 2d
from monai.networks.nets import DenseNet121
from monai.visualize import OcclusionSensitivity
+ import torch
model_2d = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
occ_sens = OcclusionSensitivity(nn_module=model_2d)
- occ_map, most_probable_class = occ_sens(x=torch.rand((1, 1, 48, 64)), b_box=[-1, -1, 2, 40, 1, 62])
+ occ_map, most_probable_class = occ_sens(x=torch.rand((1, 1, 48, 64)), b_box=[2, 40, 1, 62])
# densenet 3d
from monai.networks.nets import DenseNet
from monai.visualize import OcclusionSensitivity
model_3d = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))
- occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10, stride=3)
- occ_map, most_probable_class = occ_sens(torch.rand(1, 1, 6, 6, 6), b_box=[-1, -1, 1, 3, -1, -1, -1, -1])
+ occ_sens = OcclusionSensitivity(nn_module=model_3d, n_batch=10)
+ occ_map, most_probable_class = occ_sens(torch.rand(1, 1, 6, 6, 6), b_box=[1, 3, -1, -1, -1, -1])
See Also:
- :py:class:`monai.visualize.occlusion_sensitivity.OcclusionSensitivity.`
"""
+ @deprecated_arg(
+ name="pad_val",
+ since="1.0",
+ removed="1.2",
+ msg_suffix="Please use `mode`. For backwards compatibility, use `mode=mean_img`.",
+ )
+ @deprecated_arg(name="stride", since="1.0", removed="1.2", msg_suffix="Please use `overlap`.")
+ @deprecated_arg(name="per_channel", since="1.0", removed="1.2")
+ @deprecated_arg(name="upsampler", since="1.0", removed="1.2")
def __init__(
self,
nn_module: nn.Module,
pad_val: Optional[float] = None,
- mask_size: Union[int, Sequence] = 15,
- n_batch: int = 128,
+ mask_size: Union[int, Sequence] = 16,
+ n_batch: int = 16,
stride: Union[int, Sequence] = 1,
per_channel: bool = True,
upsampler: Optional[Callable] = default_upsampler,
verbose: bool = True,
+ mode: Union[str, float, Callable] = "gaussian",
+ overlap: float = 0.25,
+ activate: Union[bool, Callable] = True,
) -> None:
- """Occlusion sensitivity constructor.
+ """
+ Occlusion sensitivity constructor.
Args:
nn_module: Classification model to use for inference
- pad_val: When occluding part of the image, which values should we put
- in the image? If ``None`` is used, then the average of the image will be used.
- mask_size: Size of box to be occluded, centred on the central voxel. To ensure that the occluded area
- is correctly centred, ``mask_size`` and ``stride`` should both be odd or even.
+ mask_size: Size of box to be occluded, centred on the central voxel. If a single number
+ is given, this is used for all dimensions. If a sequence is given, this is used for each dimension
+ individually.
n_batch: Number of images in a batch for inference.
- stride: Stride in spatial directions for performing occlusions. Can be single
- value or sequence (for varying stride in the different directions).
- Should be >= 1. Striding in the channel direction depends on the `per_channel` argument.
- per_channel: If `True`, `mask_size` and `stride` both equal 1 in the channel dimension. If `False`,
- then both `mask_size` equals the number of channels in the image. If `True`, the output image will be:
- `[B, C, H, W, D, num_seg_classes]`. Else, will be `[B, 1, H, W, D, num_seg_classes]`
- upsampler: An upsampling method to upsample the output image. Default is
- N-dimensional linear (bilinear, trilinear, etc.) depending on num spatial
- dimensions of input.
- verbose: Use ``tqdm.trange`` output (if available).
- """
+ verbose: Use progress bar (if ``tqdm`` available).
+ mode: what should the occluded region be replaced with? If a float is given, that value will be used
+ throughout the occlusion. Else, ``gaussian``, ``mean_img`` and ``mean_patch`` can be supplied:
+
+ * ``gaussian``: occluded region is multiplied by 1 - gaussian kernel. In this fashion, the occlusion
+ will be 0 at the center and will be unchanged towards the edges, varying smoothly between. When
+ gaussian is used, a weighted average will be used to combine overlapping regions. This will be
+ done using the gaussian (not 1-gaussian) as occluded regions count more.
+ * ``mean_patch``: occluded region will be replaced with the mean of occluded region.
+ * ``mean_img``: occluded region will be replaced with the mean of the whole image.
+
+ overlap: overlap between inferred regions. Should be in range 0<=x<1.
+ activate: if ``True``, do softmax activation if num_channels > 1 else do ``sigmoid``. If ``False``, don't do any
+ activation. If ``callable``, use callable on inferred outputs.
+ """
self.nn_module = nn_module
- self.upsampler = upsampler
- self.pad_val = pad_val
self.mask_size = mask_size
self.n_batch = n_batch
- self.stride = stride
- self.per_channel = per_channel
self.verbose = verbose
+ self.overlap = overlap
+ self.activate = activate
+ # mode
+ if isinstance(mode, str) and mode not in ("gaussian", "mean_patch", "mean_img"):
+ raise NotImplementedError
+ self.mode = mode
+
+ @staticmethod
+ def constant_occlusion(x: torch.Tensor, val: float, mask_size: Sequence) -> Tuple[float, torch.Tensor]:
+ """Occlude with a constant occlusion. Multiplicative is zero, additive is constant value."""
+ ones = torch.ones((*x.shape[:2], *mask_size), device=x.device, dtype=x.dtype)
+ return 0, ones * val
+
+ @staticmethod
+ def gaussian_occlusion(x: torch.Tensor, mask_size, sigma=0.25) -> Tuple[torch.Tensor, float]:
+ """
+ For Gaussian occlusion, Multiplicative is 1-Gaussian, additive is zero.
+ Default sigma of 0.25 empirically shown to give reasonable kernel, see here:
+ https://github.com/Project-MONAI/MONAI/pull/5230#discussion_r984520714.
+ """
+ kernel = torch.zeros((x.shape[1], *mask_size), device=x.device, dtype=x.dtype)
+ spatial_shape = kernel.shape[1:]
+ # all channels (as occluded shape already takes into account per_channel), center in spatial dimensions
+ center = [slice(None)] + [slice(s // 2, s // 2 + 1) for s in spatial_shape]
+ # place value of 1 at center
+ kernel[center] = 1.0
+ # Smooth with sigma equal to quarter of image, flip +ve/-ve so largest values are at edge
+ # and smallest at center. Scale to [0, 1].
+ gaussian = Compose(
+ [GaussianSmooth(sigma=[b * sigma for b in spatial_shape]), Lambda(lambda x: -x), ScaleIntensity()]
+ )
+ # transform and add batch
+ mul: torch.Tensor = gaussian(kernel)[None]
+
+ return mul, 0
+
+ @staticmethod
+ def predictor(
+ cropped_grid: torch.Tensor,
+ nn_module: nn.Module,
+ x: torch.Tensor,
+ mul: Union[torch.Tensor, float],
+ add: Union[torch.Tensor, float],
+ mask_size: Sequence,
+ occ_mode: str,
+ activate: Union[bool, Callable],
+ module_kwargs,
+ ) -> torch.Tensor:
+ """
+ Predictor function to be passed to the sliding window inferer. Takes a cropped meshgrid,
+ referring to the coordinates in the input image. We use the index of the top-left corner
+ in combination ``mask_size`` to figure out which region of the image is to be occluded. The
+ occlusion is performed on the original image, ``x``, using ``cropped_region * mul + add``. ``mul``
+ and ``add`` are sometimes pre-computed (e.g., a constant Gaussian blur), or they are
+ sometimes calculated on the fly (e.g., the mean of the occluded patch). For this reason
+ ``occ_mode`` is given. Lastly, ``activate`` is used to activate after each call of the model.
- def _compute_occlusion_sensitivity(self, x, b_box):
-
- # Get bounding box
- im_shape = np.array(x.shape[1:])
- b_box_min, b_box_max = _check_input_bounding_box(b_box, im_shape)
-
- # Get the number of prediction classes
- num_classes = self.nn_module(x).numel()
-
- # If pad val not supplied, get the mean of the image
- pad_val = x.mean() if self.pad_val is None else self.pad_val
-
- # List containing a batch of images to be inferred
- batch_images = []
-
- # List of sensitivity images, one for each inferred class
- sensitivity_ims = num_classes * [torch.empty(0, dtype=torch.float32, device=x.device)]
-
- # If no bounding box supplied, output shape is same as input shape.
- # If bounding box is present, shape is max - min + 1
- output_im_shape = im_shape if b_box is None else b_box_max - b_box_min + 1
-
- # Get the stride and mask_size as numpy arrays
- stride = _get_as_np_array(self.stride, len(im_shape))
- mask_size = _get_as_np_array(self.mask_size, len(im_shape))
-
- # If not doing it on a per-channel basis, then the output image will have 1 output channel
- # (since all will be occluded together)
- if not self.per_channel:
- output_im_shape[0] = 1
- stride[0] = x.shape[1]
- mask_size[0] = x.shape[1]
-
- # For each dimension, ...
- for o, s in zip(output_im_shape, stride):
- # if the size is > 1, then check that the stride is a factor of the output image shape
- if o > 1 and o % s != 0:
- raise ValueError(
- "Stride should be a factor of the image shape. Im shape "
- + f"(taking bounding box into account): {output_im_shape}, stride: {stride}"
- )
-
- # to ensure the occluded area is nicely centred if stride is even, ensure that so is the mask_size
- if np.any(mask_size % 2 != stride % 2):
- raise ValueError(
- "Stride and mask size should both be odd or even (element-wise). "
- + f"``stride={stride}``, ``mask_size={mask_size}``"
- )
-
- downsampled_im_shape = (output_im_shape / stride).astype(np.int32)
- downsampled_im_shape[downsampled_im_shape == 0] = 1 # make sure dimension sizes are >= 1
- num_required_predictions = np.prod(downsampled_im_shape)
-
- # Get bottom left and top right corners of occluded region
- lower_corner = (stride - mask_size) // 2
- upper_corner = (stride + mask_size) // 2
-
- # Loop 1D over image
- verbose_range = trange if self.verbose else range
- for i in verbose_range(num_required_predictions):
- # Get corresponding ND index
- idx = np.unravel_index(i, downsampled_im_shape)
- # Multiply by stride
- idx *= stride
- # If a bounding box is being used, we need to add on
- # the min to shift to start of region of interest
- if b_box_min is not None:
- idx += b_box_min
-
- # Get min and max index of box to occlude (and make sure it's in bounds)
- min_idx = np.maximum(idx + lower_corner, 0)
- max_idx = np.minimum(idx + upper_corner, im_shape)
-
- # Clone and replace target area with `pad_val`
- occlu_im = x.detach().clone()
- occlu_im[(...,) + tuple(slice(i, j) for i, j in zip(min_idx, max_idx))] = pad_val
-
- # Add to list
- batch_images.append(occlu_im)
-
- # Once the batch is complete (or on last iteration)
- if len(batch_images) == self.n_batch or i == num_required_predictions - 1:
- # Do the predictions and append to sensitivity maps
- sensitivity_ims = _append_to_sensitivity_ims(self.nn_module, batch_images, sensitivity_ims)
- # Clear lists
- batch_images = []
-
- # Reshape to match downsampled image, and unsqueeze to add batch dimension back in
- for i in range(num_classes):
- sensitivity_ims[i] = sensitivity_ims[i].reshape(tuple(downsampled_im_shape)).unsqueeze(0)
-
- return sensitivity_ims, output_im_shape
-
- def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ Args:
+ cropped_grid: subsection of the meshgrid, where each voxel refers to the coordinate of
+ the input image. The meshgrid is created by the ``OcclusionSensitivity`` class, and
+ the generation of the subset is determined by ``sliding_window_inference``.
+ nn_module: module to call on data.
+ x: the image that was originally passed into ``OcclusionSensitivity.__call__``.
+ mul: occluded region will be multiplied by this. Can be ``torch.Tensor`` or ``float``.
+ add: after multiplication, this is added to the occluded region. Can be ``torch.Tensor`` or ``float``.
+ mask_size: Size of box to be occluded, centred on the central voxel. Should be
+ a sequence, one value for each spatial dimension.
+ occ_mode: might be used to calculate ``mul`` and ``add`` on the fly.
+ activate: if ``True``, do softmax activation if num_channels > 1 else do ``sigmoid``. If ``False``, don't do any
+ activation. If ``callable``, use callable on inferred outputs.
+ module_kwargs: kwargs to be passed onto module when inferring
+ """
+ n_batch = cropped_grid.shape[0]
+ sd = cropped_grid.ndim - 2
+ # start with copies of x to infer
+ im = torch.repeat_interleave(x, n_batch, 0)
+ # get coordinates of top left corner of occluded region (possible because we use meshgrid)
+ corner_coord_slices = [slice(None)] * 2 + [slice(1)] * sd
+ top_corners = cropped_grid[corner_coord_slices]
+
+ # replace occluded regions
+ for b, t in enumerate(top_corners):
+ # starting from corner, get the slices to extract the occluded region from the image
+ slices = [slice(b, b + 1), slice(None)] + [slice(int(j), int(j) + m) for j, m in zip(t, mask_size)]
+ to_occlude = im[slices]
+ if occ_mode == "mean_patch":
+ add, mul = OcclusionSensitivity.constant_occlusion(x, to_occlude.mean().item(), mask_size)
+
+ if callable(occ_mode):
+ to_occlude = occ_mode(x, to_occlude)
+ else:
+ to_occlude = to_occlude * mul + add
+ if add is None or mul is None:
+ raise RuntimeError("Shouldn't be here, something's gone wrong...")
+ im[slices] = to_occlude
+ # infer
+ out: torch.Tensor = nn_module(im, **module_kwargs)
+
+ # if activation is callable, call it
+ if callable(activate):
+ out = activate(out)
+ # else if True (should be boolean), sigmoid if n_chan == 1 else softmax
+ elif activate:
+ out = out.sigmoid() if x.shape[1] == 1 else out.softmax(1)
+
+ # the output will have shape [B,C] where C is number of channels output by model (inference classes)
+ # we need to return it to sliding window inference with shape [B,C,H,W,[D]], so add dims and repeat values
+ for m in mask_size:
+ out = torch.repeat_interleave(out.unsqueeze(-1), m, dim=-1)
+
+ return out
+
+ @staticmethod
+ def crop_meshgrid(
+ grid: MetaTensor, b_box: Sequence, mask_size: Sequence
+ ) -> Tuple[MetaTensor, SpatialCrop, Sequence]:
+ """Crop the meshgrid so we only perform occlusion sensitivity on a subsection of the image."""
+ # distance from center of mask to edge is -1 // 2.
+ mask_edge = [(m - 1) // 2 for m in mask_size]
+ bbox_min = [max(b - m, 0) for b, m in zip(b_box[::2], mask_edge)]
+ bbox_max = []
+ for b, m, s in zip(b_box[1::2], mask_edge, grid.shape[2:]):
+ # if bbox is -ve for that dimension, no cropping so use current image size
+ if b == -1:
+ bbox_max.append(s)
+ # else bounding box plus distance to mask edge. Make sure it's not bigger than the size of the image
+ else:
+ bbox_max.append(min(b + m, s))
+ # bbox_max = [min(b + m, s) if b >= 0 else s for b, m, s in zip(b_box[1::2], mask_edge, grid.shape[2:])]
+ # No need for batch and channel slices. Batch will be removed and added back in, and
+ # SpatialCrop doesn't act on the first dimension anyway.
+ slices = [slice(s, e) for s, e in zip(bbox_min, bbox_max)]
+ cropper = SpatialCrop(roi_slices=slices)
+ cropped: MetaTensor = cropper(grid[0])[None] # type: ignore
+ mask_size = list(mask_size)
+ for i, s in enumerate(cropped.shape[2:]):
+ mask_size[i] = min(s, mask_size[i])
+ return cropped, cropper, mask_size
+
+ def __call__(
+ self, x: torch.Tensor, b_box: Optional[Sequence] = None, **kwargs
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Image to use for inference. Should be a tensor consisting of 1 batch.
b_box: Bounding box on which to perform the analysis. The output image will be limited to this size.
- There should be a minimum and maximum for all dimensions except batch: ``[min1, max1, min2, max2,...]``.
+ There should be a minimum and maximum for all spatial dimensions: ``[min1, max1, min2, max2,...]``.
* By default, the whole image will be used. Decreasing the size will speed the analysis up, which might
be useful for larger images.
* Min and max are inclusive, so ``[0, 63, ...]`` will have size ``(64, ...)``.
* Use -ve to use ``min=0`` and ``max=im.shape[x]-1`` for xth dimension.
+ * N.B.: we add half of the mask size to the bounding box to ensure that the region of interest has a
+ sufficiently large area surrounding it.
+ kwargs: any extra arguments to be passed on to the module as part of its `__call__`.
Returns:
* Occlusion map:
@@ -298,29 +290,73 @@ def __call__(self, x: torch.Tensor, b_box: Optional[Sequence] = None) -> Tuple[t
* The most probable class when the corresponding part of the image is occluded (``argmax(dim=-1)``).
Both images will be cropped if a bounding box used, but voxel sizes will always match the input.
"""
+ if x.shape[0] > 1:
+ raise ValueError("Expected batch size of 1.")
+
+ sd = x.ndim - 2
+ mask_size: Sequence = ensure_tuple_rep(self.mask_size, sd)
+
+ # get the meshgrid (so that sliding_window_inference can tell us which bit to occlude)
+ grid: MetaTensor = MetaTensor(
+ np.stack(np.meshgrid(*[np.arange(0, i) for i in x.shape[2:]], indexing="ij"))[None],
+ device=x.device,
+ dtype=x.dtype,
+ )
+ # if bounding box given, crop the grid to only infer subsections of the image
+ if b_box is not None:
+ grid, cropper, mask_size = self.crop_meshgrid(grid, b_box, mask_size)
+
+ # check that the grid is bigger than the mask size
+ if any(m > g for g, m in zip(grid.shape[2:], mask_size)):
+ raise ValueError(f"Image (spatial shape) {grid.shape[2:]} should be bigger than mask {mask_size}.")
+
+ # get additive and multiplicative factors if they are unchanged for all patches (i.e., not mean_patch)
+ add: Optional[Union[float, torch.Tensor]]
+ mul: Optional[Union[float, torch.Tensor]]
+ # multiply by 0, add value
+ if isinstance(self.mode, float):
+ mul, add = self.constant_occlusion(x, self.mode, mask_size)
+ # multiply by 0, add mean of image
+ elif self.mode == "mean_img":
+ mul, add = self.constant_occlusion(x, x.mean().item(), mask_size)
+ # for gaussian, additive = 0, multiplicative = gaussian
+ elif self.mode == "gaussian":
+ mul, add = self.gaussian_occlusion(x, mask_size)
+ # else will be determined on each patch individually so calculated later
+ else:
+ add, mul = None, None
with eval_mode(self.nn_module):
+ # needs to go here to avoid cirular import
+ from monai.inferers import sliding_window_inference
+
+ sensitivity_im: MetaTensor = sliding_window_inference( # type: ignore
+ grid,
+ roi_size=mask_size,
+ sw_batch_size=self.n_batch,
+ predictor=OcclusionSensitivity.predictor,
+ overlap=self.overlap,
+ mode="gaussian" if self.mode == "gaussian" else "constant",
+ progress=self.verbose,
+ nn_module=self.nn_module,
+ x=x,
+ add=add,
+ mul=mul,
+ mask_size=mask_size,
+ occ_mode=self.mode,
+ activate=self.activate,
+ module_kwargs=kwargs,
+ )
- # Check input arguments
- _check_input_image(x)
-
- # Generate sensitivity images
- sensitivity_ims_list, output_im_shape = self._compute_occlusion_sensitivity(x, b_box)
-
- # Loop over image for each classification
- for i, sens_i in enumerate(sensitivity_ims_list):
- # upsample
- if self.upsampler is not None:
- if len(sens_i.shape) != len(x.shape):
- raise AssertionError
- if np.any(sens_i.shape != x.shape):
- img_spatial = tuple(output_im_shape[1:])
- sensitivity_ims_list[i] = self.upsampler(img_spatial)(sens_i)
-
- # Convert list of tensors to tensor
- sensitivity_ims = torch.stack(sensitivity_ims_list, dim=-1)
-
- # The most probable class is the max in the classification dimension (last)
- most_probable_class = sensitivity_ims.argmax(dim=-1)
-
- return sensitivity_ims, most_probable_class
+ if b_box is not None:
+ # undo the cropping that was applied to the meshgrid
+ sensitivity_im = cropper.inverse(sensitivity_im[0])[None] # type: ignore
+ # crop using the bounding box (ignoring the mask size this time)
+ bbox_min = [max(b, 0) for b in b_box[::2]]
+ bbox_max = [b if b > 0 else s for b, s in zip(b_box[1::2], x.shape[2:])]
+ cropper = SpatialCrop(roi_start=bbox_min, roi_end=bbox_max)
+ sensitivity_im = cropper(sensitivity_im[0])[None] # type: ignore
+
+ # The most probable class is the max in the classification dimension (1)
+ most_probable_class = sensitivity_im.argmax(dim=1, keepdim=True)
+ return sensitivity_im, most_probable_class
diff --git a/monai/visualize/visualizer.py b/monai/visualize/visualizer.py
index 5f19e4f63ff..05ebb2e2800 100644
--- a/monai/visualize/visualizer.py
+++ b/monai/visualize/visualizer.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Callable
import torch
diff --git a/pyproject.toml b/pyproject.toml
index 59db27e134b..2505eb0d45d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,7 @@
requires = [
"wheel",
"setuptools",
- "torch>=1.7",
+ "torch>=1.8",
"ninja",
]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2568bf9c1eb..8b3faee47f2 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,9 +1,9 @@
# Full requirements for developments
-r requirements-min.txt
-pytorch-ignite==0.4.9
+pytorch-ignite==0.4.10
gdown>=4.4.0
scipy
-itk>=5.2
+itk>=5.2; python_version < "3.10"
nibabel
pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571
tensorboard
@@ -31,14 +31,14 @@ Sphinx==3.5.3
recommonmark==0.6.0
sphinx-autodoc-typehints==1.11.1
sphinx-rtd-theme==0.5.2
-cucim==22.2.1; platform_system == "Linux"
+cucim==22.8.1; platform_system == "Linux"
openslide-python==1.1.2
imagecodecs; platform_system == "Linux"
tifffile; platform_system == "Linux"
pandas
requests
einops
-transformers
+transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157
mlflow
matplotlib!=3.5.0
tensorboardX
@@ -50,3 +50,5 @@ pynrrd
pre-commit
pydicom
h5py
+nni
+optuna
diff --git a/requirements.txt b/requirements.txt
index 14eb2b30e9e..ba7d7be6d79 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-torch>=1.7
+torch>=1.8
numpy>=1.17
diff --git a/runtests.sh b/runtests.sh
index 78f16727455..490b50b1365 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -14,13 +14,6 @@
# script for running all tests
set -e
-# FIXME: https://github.com/Project-MONAI/MONAI/issues/4354
-protobuf_major_version=$(pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1)
-if [ "$protobuf_major_version" -ge "4" ]
-then
- export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
-fi
-
# output formatting
separator=""
blue=""
@@ -118,6 +111,13 @@ function print_usage {
exit 1
}
+# FIXME: https://github.com/Project-MONAI/MONAI/issues/4354
+protobuf_major_version=$(${PY_EXE} -m pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1)
+if [ "$protobuf_major_version" -ge "4" ]
+then
+ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+fi
+
function check_import {
echo "Python: ${PY_EXE}"
${cmdPrefix}${PY_EXE} -W error -W ignore::DeprecationWarning -c "import monai"
@@ -681,5 +681,5 @@ if [ $doCoverage = true ]
then
echo "${separator}${blue}coverage${noColor}"
${cmdPrefix}${PY_EXE} -m coverage combine --append .coverage/
- ${cmdPrefix}${PY_EXE} -m coverage report
+ ${cmdPrefix}${PY_EXE} -m coverage report --ignore-errors
fi
diff --git a/setup.cfg b/setup.cfg
index 09219bfc32a..8d882af2824 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -24,29 +24,30 @@ setup_requires =
torch
ninja
install_requires =
- torch>=1.7
+ torch>=1.8
numpy>=1.17
[options.extras_require]
all =
nibabel
+ ninja
scikit-image>=0.14.2
pillow
tensorboard
gdown>=4.4.0
- pytorch-ignite==0.4.9
+ pytorch-ignite==0.4.10
torchvision
itk>=5.2
tqdm>=4.47.0
lmdb
psutil
- cucim>=21.8.2
+ cucim>=22.8.1
openslide-python==1.1.2
tifffile
imagecodecs
pandas
einops
- transformers
+ transformers<4.22
mlflow
matplotlib
tensorboardX
@@ -56,8 +57,12 @@ all =
pynrrd
pydicom
h5py
+ nni
+ optuna
nibabel =
nibabel
+ninja =
+ ninja
skimage =
scikit-image>=0.14.2
pillow =
@@ -67,7 +72,7 @@ tensorboard =
gdown =
gdown>=4.4.0
ignite =
- pytorch-ignite==0.4.9
+ pytorch-ignite==0.4.10
torchvision =
torchvision
itk =
@@ -79,7 +84,7 @@ lmdb =
psutil =
psutil
cucim =
- cucim>=21.8.2
+ cucim>=22.8.1
openslide =
openslide-python==1.1.2
tifffile =
@@ -91,7 +96,7 @@ pandas =
einops =
einops
transformers =
- transformers
+ transformers<4.22
mlflow =
mlflow
matplotlib =
diff --git a/tests/__init__.py b/tests/__init__.py
index 4639a584969..0d6e28a6798 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -29,7 +29,7 @@ def _enter_pr_4800(self):
return self
-# workaround for https://bugs.python.org/issue29620
+# FIXME: workaround for https://bugs.python.org/issue29620
try:
# Suppression for issue #494: tests/__init__.py:34: error: Cannot assign to a method
unittest.case._AssertWarnsContext.__enter__ = _enter_pr_4800 # type: ignore
diff --git a/tests/hvd_evenly_divisible_all_gather.py b/tests/hvd_evenly_divisible_all_gather.py
index cf8254b614b..a79038f4be8 100644
--- a/tests/hvd_evenly_divisible_all_gather.py
+++ b/tests/hvd_evenly_divisible_all_gather.py
@@ -13,6 +13,7 @@
from monai.utils import evenly_divisible_all_gather
from monai.utils.module import optional_import
+from tests.utils import assert_allclose
hvd, has_hvd = optional_import("horovod", name="torch")
@@ -37,13 +38,13 @@ def _run(self):
data3 = torch.tensor(8)
result1 = evenly_divisible_all_gather(data=data1, concat=True)
- torch.testing.assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]]))
+ assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]]))
result2 = evenly_divisible_all_gather(data=data2, concat=False)
for r, e in zip(result2, [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0], [5.0, 6.0]])]):
- torch.testing.assert_allclose(r, e)
+ assert_allclose(r, e)
result3 = evenly_divisible_all_gather(data=data3, concat=False)
for r in result3:
- torch.testing.assert_allclose(r.ndimension(), 0)
+ assert_allclose(r.ndimension(), 0)
if __name__ == "__main__":
diff --git a/tests/min_tests.py b/tests/min_tests.py
index f33af553f30..7c06d213744 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -29,12 +29,18 @@ def run_testsuit():
exclude_cases = [ # these cases use external dependencies
"test_ahnet",
"test_arraydataset",
+ "test_auto3dseg_ensemble",
+ "test_auto3dseg_hpo",
+ "test_auto3dseg",
"test_cachedataset",
"test_cachedataset_parallel",
"test_cachedataset_persistent_workers",
"test_cachentransdataset",
- "test_contrastive_loss",
"test_check_missing_files",
+ "test_compute_ho_ver_maps",
+ "test_compute_ho_ver_maps_d",
+ "test_compute_panoptic_quality",
+ "test_contrastive_loss",
"test_csv_dataset",
"test_csv_iterable_dataset",
"test_cumulative_average_dist",
@@ -56,6 +62,8 @@ def run_testsuit():
"test_foreground_mask",
"test_foreground_maskd",
"test_global_mutual_information_loss",
+ "test_grid_patch",
+ "test_gmm",
"test_handler_checkpoint_loader",
"test_handler_checkpoint_saver",
"test_handler_classification_saver",
@@ -68,6 +76,8 @@ def run_testsuit():
"test_handler_hausdorff_distance",
"test_handler_lr_scheduler",
"test_handler_mean_dice",
+ "test_handler_panoptic_quality",
+ "test_handler_mean_iou",
"test_handler_metrics_saver",
"test_handler_metrics_saver_dist",
"test_handler_mlflow",
@@ -88,6 +98,7 @@ def run_testsuit():
"test_hausdorff_distance",
"test_header_correct",
"test_hilbert_transform",
+ "test_hovernet_loss",
"test_image_dataset",
"test_image_rw",
"test_img2tensorboard",
@@ -98,6 +109,7 @@ def run_testsuit():
"test_integration_workflows",
"test_integration_workflows_gan",
"test_integration_bundle_run",
+ "test_integration_autorunner",
"test_invert",
"test_invertd",
"test_iterable_dataset",
@@ -130,11 +142,14 @@ def run_testsuit():
"test_png_saver",
"test_prepare_batch_default",
"test_prepare_batch_extra_input",
+ "test_prepare_batch_hovernet",
+ "test_rand_grid_patch",
"test_rand_rotate",
"test_rand_rotated",
"test_rand_zoom",
"test_rand_zoomd",
"test_randtorchvisiond",
+ "test_resample_backends",
"test_resize",
"test_resized",
"test_resample_to_match",
@@ -162,7 +177,6 @@ def run_testsuit():
"test_vitautoenc",
"test_write_metrics_reports",
"test_wsireader",
- "test_wsireader_new",
"test_zoom",
"test_zoom_affine",
"test_zoomd",
@@ -198,10 +212,7 @@ def run_testsuit():
from monai.utils.module import load_submodules
_, err_mod = load_submodules(sys.modules["monai"], True)
- if err_mod:
- print(err_mod)
- # expecting that only engines and handlers are not imported
- assert sorted(err_mod) == ["monai.engines", "monai.handlers"]
+ assert not err_mod, f"err_mod={err_mod} not empty"
# testing all modules
test_runner = unittest.TextTestRunner(stream=sys.stdout, verbosity=2)
diff --git a/tests/profile_subclass/cprofile_profiling.py b/tests/profile_subclass/cprofile_profiling.py
index a6c940c9c0d..0befa0f450e 100644
--- a/tests/profile_subclass/cprofile_profiling.py
+++ b/tests/profile_subclass/cprofile_profiling.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
Profiling MetaTensor
"""
diff --git a/tests/profile_subclass/min_classes.py b/tests/profile_subclass/min_classes.py
index 87c0ce671dc..702ba73e210 100644
--- a/tests/profile_subclass/min_classes.py
+++ b/tests/profile_subclass/min_classes.py
@@ -8,8 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
"""
Minimal subclassing as baselines
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
diff --git a/tests/profile_subclass/profiling.py b/tests/profile_subclass/profiling.py
index 28740e82e19..46047b619c0 100644
--- a/tests/profile_subclass/profiling.py
+++ b/tests/profile_subclass/profiling.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
Comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
diff --git a/tests/profile_subclass/pyspy_profiling.py b/tests/profile_subclass/pyspy_profiling.py
index 302bfd39c3f..1caeee69e7d 100644
--- a/tests/profile_subclass/pyspy_profiling.py
+++ b/tests/profile_subclass/pyspy_profiling.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""
To be used with py-spy, comparing torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
diff --git a/tests/test_activations.py b/tests/test_activations.py
index a06316b2538..503ca0a3507 100644
--- a/tests/test_activations.py
+++ b/tests/test_activations.py
@@ -38,6 +38,15 @@
]
)
+ TEST_CASES.append(
+ [
+ {"sigmoid": False, "softmax": True, "other": None, "unused": True, "dim": 1},
+ p([[[0.0, 1.0]], [[2.0, 3.0]]]),
+ p([[[1.0, 1.0]], [[1.0, 1.0]]]),
+ (2, 1, 2),
+ ]
+ )
+
TEST_CASES.append(
[
{"sigmoid": False, "softmax": False, "other": torch.tanh},
@@ -94,7 +103,7 @@ def _compare(ret, out, shape):
def test_monai_activations_value_shape(self, input_param, img, out, expected_shape):
act = Act[input_param]()
result = act(img)
- torch.testing.assert_allclose(result, out, rtol=1e-2, atol=1e-5)
+ assert_allclose(result, out, rtol=1e-2, atol=1e-5)
self.assertTupleEqual(result.shape, expected_shape)
diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py
index e38f36e49d0..a8f8f600a43 100644
--- a/tests/test_activationsd.py
+++ b/tests/test_activationsd.py
@@ -21,7 +21,7 @@
for p in TEST_NDARRAYS:
TEST_CASES.append(
[
- {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None},
+ {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None, "dim": 0},
{"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": p([[[0.0, 1.0]], [[2.0, 3.0]]])},
{"pred": p([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), "label": p([[[0.0, 1.0]], [[2.0, 3.0]]])},
(2, 1, 2),
diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py
index b481601df52..23651c8b6b9 100644
--- a/tests/test_affine_grid.py
+++ b/tests/test_affine_grid.py
@@ -129,7 +129,6 @@
]
)
-
_rtol = 5e-2 if is_tf32_env() else 1e-4
diff --git a/tests/test_apply_filter.py b/tests/test_apply_filter.py
index 3174211f340..62372516a5c 100644
--- a/tests/test_apply_filter.py
+++ b/tests/test_apply_filter.py
@@ -64,7 +64,6 @@ def test_3d(self):
],
]
)
- expected = expected
# testing shapes
k = torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
for kernel in (k, k[None], k[None][None]):
diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py
index 867ef84062d..014e439fe1f 100644
--- a/tests/test_as_discrete.py
+++ b/tests/test_as_discrete.py
@@ -29,7 +29,7 @@
TEST_CASES.append(
[
- {"argmax": True, "to_onehot": 2, "threshold": 0.5},
+ {"argmax": True, "to_onehot": 2, "threshold": 0.5, "dim": 0},
p([[[0.0, 1.0]], [[2.0, 3.0]]]),
p([[[0.0, 0.0]], [[1.0, 1.0]]]),
(2, 1, 2),
@@ -69,6 +69,11 @@ def test_value_shape(self, input_param, img, out, expected_shape):
assert_allclose(result, out, rtol=1e-3, type_test="tensor")
self.assertTupleEqual(result.shape, expected_shape)
+ def test_additional(self):
+ for p in TEST_NDARRAYS:
+ out = AsDiscrete(argmax=True, dim=1, keepdim=False)(p([[[0.0, 1.0]], [[2.0, 3.0]]]))
+ assert_allclose(out, p([[0.0, 0.0], [0.0, 0.0]]), type_test=False)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py
index 17527c0fd47..dc96a4218ba 100644
--- a/tests/test_as_discreted.py
+++ b/tests/test_as_discreted.py
@@ -38,7 +38,7 @@
TEST_CASES.append(
[
- {"keys": ["pred"], "argmax": True, "to_onehot": 2, "threshold": 0.5},
+ {"keys": ["pred"], "argmax": True, "to_onehot": 2, "threshold": 0.5, "dim": 0, "keepdim": True},
{"pred": p([[[0.0, 1.0]], [[2.0, 3.0]]])},
{"pred": p([[[0.0, 0.0]], [[1.0, 1.0]]])},
(2, 1, 2),
@@ -54,22 +54,6 @@
]
)
- # test compatible with previous versions
- TEST_CASES.append(
- [
- {
- "keys": ["pred", "label"],
- "argmax": False,
- "to_onehot": None,
- "threshold": [True, None],
- "logit_thresh": 0.6,
- },
- {"pred": p([[[0.0, 1.0], [2.0, 3.0]]]), "label": p([[[0, 1], [1, 1]]])},
- {"pred": p([[[0.0, 1.0], [1.0, 1.0]]]), "label": p([[[0.0, 1.0], [1.0, 1.0]]])},
- (1, 2, 2),
- ]
- )
-
# test threshold = 0.0
TEST_CASES.append(
[
diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py
index b2f53f9c16a..e1df7b8acd2 100644
--- a/tests/test_attentionunet.py
+++ b/tests/test_attentionunet.py
@@ -39,7 +39,7 @@ def test_attentionunet(self):
shape = (3, 1) + (92,) * dims
input = torch.rand(*shape)
model = att.AttentionUnet(
- spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2)
+ spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), up_kernel_size=5, strides=(1, 2)
)
output = model(input)
self.assertEqual(output.shape[2:], input.shape[2:])
diff --git a/tests/test_auto3dseg.py b/tests/test_auto3dseg.py
new file mode 100644
index 00000000000..74c45f6fece
--- /dev/null
+++ b/tests/test_auto3dseg.py
@@ -0,0 +1,475 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+import unittest
+from copy import deepcopy
+from numbers import Number
+
+import nibabel as nib
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.apps.auto3dseg import DataAnalyzer
+from monai.auto3dseg import (
+ Analyzer,
+ FgImageStats,
+ FgImageStatsSumm,
+ FilenameStats,
+ ImageStats,
+ ImageStatsSumm,
+ LabelStats,
+ LabelStatsSumm,
+ Operations,
+ SampleOperations,
+ SegSummarizer,
+ SummaryOperations,
+ datafold_read,
+ verify_report_format,
+)
+from monai.bundle import ConfigParser
+from monai.data import DataLoader, Dataset, create_test_image_2d, create_test_image_3d
+from monai.data.meta_tensor import MetaTensor
+from monai.data.utils import no_collation
+from monai.transforms import (
+ Compose,
+ EnsureChannelFirstd,
+ EnsureTyped,
+ Lambdad,
+ LoadImaged,
+ Orientationd,
+ SqueezeDimd,
+ ToDeviced,
+)
+from monai.utils.enums import DataStatsKeys
+from tests.utils import skip_if_no_cuda
+
+device = "cpu"
+n_workers = 2
+
+sim_datalist = {
+ "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}],
+ "training": [
+ {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_002.fake.nii.gz", "label": "tr_label_002.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_004.fake.nii.gz", "label": "tr_label_004.fake.nii.gz"},
+ ],
+}
+
+SIM_CPU_TEST_CASES = [
+ [{"sim_dim": (32, 32, 32), "label_key": "label"}],
+ [{"sim_dim": (32, 32, 32, 2), "label_key": "label"}],
+ [{"sim_dim": (32, 32, 32), "label_key": None}],
+ [{"sim_dim": (32, 32, 32), "label_key": "None"}],
+]
+
+SIM_GPU_TEST_CASES = [[{"sim_dim": (32, 32, 32), "label_key": "label"}], [{"sim_dim": (32, 32, 32), "label_key": None}]]
+
+
+def create_sim_data(dataroot: str, sim_datalist: dict, sim_dim: tuple, image_only: bool = False, **kwargs) -> None:
+ """
+ Create simulated data using create_test_image_3d.
+
+ Args:
+ dataroot: data directory path that hosts the "nii.gz" image files.
+ sim_datalist: a list of data to create.
+ sim_dim: the image sizes, for examples: a tuple of (64, 64, 64) for 3d, or (128, 128) for 2d
+ """
+ if not os.path.isdir(dataroot):
+ os.makedirs(dataroot)
+
+ # Generate a fake dataset
+ for d in sim_datalist["testing"] + sim_datalist["training"]:
+ if len(sim_dim) == 2: # 2D image
+ im, seg = create_test_image_2d(sim_dim[0], sim_dim[1], **kwargs)
+ elif len(sim_dim) == 3: # 3D image
+ im, seg = create_test_image_3d(sim_dim[0], sim_dim[1], sim_dim[2], **kwargs)
+ elif len(sim_dim) == 4: # multi-modality 3D image
+ im_list = []
+ seg_list = []
+ for _ in range(sim_dim[3]):
+ im_3d, seg_3d = create_test_image_3d(sim_dim[0], sim_dim[1], sim_dim[2], **kwargs)
+ im_list.append(im_3d[..., np.newaxis])
+ seg_list.append(seg_3d[..., np.newaxis])
+ im = np.concatenate(im_list, axis=3)
+ seg = np.concatenate(seg_list, axis=3)
+ else:
+ raise ValueError(f"Invalid argument input. sim_dim has f{len(sim_dim)} values. 2-4 values are expected.")
+ nib_image = nib.Nifti1Image(im, affine=np.eye(4))
+ image_fpath = os.path.join(dataroot, d["image"])
+ nib.save(nib_image, image_fpath)
+
+ if not image_only and "label" in d:
+ nib_image = nib.Nifti1Image(seg, affine=np.eye(4))
+ label_fpath = os.path.join(dataroot, d["label"])
+ nib.save(nib_image, label_fpath)
+
+
+class TestOperations(Operations):
+ """
+ Test example for user operation
+ """
+
+ def __init__(self) -> None:
+ self.data = {"max": np.max, "mean": np.mean, "min": np.min}
+
+
+class TestAnalyzer(Analyzer):
+ """
+ Test example for a simple Analyzer
+ """
+
+ def __init__(self, key, report_format, stats_name="test"):
+ self.key = key
+ super().__init__(stats_name, report_format)
+
+ def __call__(self, data):
+ d = dict(data)
+ report = deepcopy(self.get_report_format())
+ report["stats"] = self.ops["stats"].evaluate(d[self.key])
+ d[self.stats_name] = report
+ return d
+
+
+class TestImageAnalyzer(Analyzer):
+ """
+ Test example for a simple Analyzer
+ """
+
+ def __init__(self, image_key="image", stats_name="test_image"):
+
+ self.image_key = image_key
+ report_format = {"test_stats": None}
+
+ super().__init__(stats_name, report_format)
+ self.update_ops("test_stats", TestOperations())
+
+ def __call__(self, data):
+ d = dict(data)
+ report = deepcopy(self.get_report_format())
+ report["test_stats"] = self.ops["test_stats"].evaluate(d[self.image_key])
+ d[self.stats_name] = report
+ return d
+
+
+class TestDataAnalyzer(unittest.TestCase):
+ def setUp(self):
+ self.test_dir = tempfile.TemporaryDirectory()
+ work_dir = self.test_dir.name
+ self.dataroot_dir = os.path.join(work_dir, "sim_dataroot")
+ self.datalist_file = os.path.join(work_dir, "sim_datalist.json")
+ self.datastat_file = os.path.join(work_dir, "data_stats.yaml")
+ ConfigParser.export_config_file(sim_datalist, self.datalist_file)
+
+ @parameterized.expand(SIM_CPU_TEST_CASES)
+ def test_data_analyzer_cpu(self, input_params):
+
+ sim_dim = input_params["sim_dim"]
+ label_key = input_params["label_key"]
+ image_only = not bool(label_key)
+ rmax = max(int(sim_dim[0] / 4), 1)
+ create_sim_data(
+ self.dataroot_dir, sim_datalist, sim_dim, image_only=image_only, rad_max=rmax, rad_min=1, num_seg_classes=1
+ )
+
+ analyser = DataAnalyzer(
+ self.datalist_file, self.dataroot_dir, output_path=self.datastat_file, label_key=label_key
+ )
+ datastat = analyser.get_all_case_stats()
+
+ assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])
+
+ @parameterized.expand(SIM_GPU_TEST_CASES)
+ @skip_if_no_cuda
+ def test_data_analyzer_gpu(self, input_params):
+ sim_dim = input_params["sim_dim"]
+ label_key = input_params["label_key"]
+ image_only = not bool(label_key)
+ rmax = max(int(sim_dim[0] / 4), 1)
+ create_sim_data(
+ self.dataroot_dir, sim_datalist, sim_dim, image_only=image_only, rad_max=rmax, rad_min=1, num_seg_classes=1
+ )
+ analyser = DataAnalyzer(
+ self.datalist_file, self.dataroot_dir, output_path=self.datastat_file, label_key=label_key, device="cuda"
+ )
+ datastat = analyser.get_all_case_stats()
+
+ assert len(datastat["stats_by_cases"]) == len(sim_datalist["training"])
+
+ def test_basic_operation_class(self):
+ op = TestOperations()
+ test_data = np.random.rand(10, 10).astype(np.float64)
+ test_ret_1 = op.evaluate(test_data)
+ test_ret_2 = op.evaluate(test_data, axis=0)
+ assert isinstance(test_ret_1, dict) and isinstance(test_ret_2, dict)
+ assert ("max" in test_ret_1) and ("max" in test_ret_2)
+ assert ("mean" in test_ret_1) and ("mean" in test_ret_2)
+ assert ("min" in test_ret_1) and ("min" in test_ret_2)
+ assert isinstance(test_ret_1["max"], np.float64)
+ assert isinstance(test_ret_2["max"], np.ndarray)
+ assert test_ret_1["max"].ndim == 0
+ assert test_ret_2["max"].ndim == 1
+
+ def test_sample_operations(self):
+ op = SampleOperations()
+ test_data_np = np.random.rand(10, 10).astype(np.float64)
+ test_data_mt = MetaTensor(test_data_np, device=device)
+ test_ret_np = op.evaluate(test_data_np)
+ test_ret_mt = op.evaluate(test_data_mt)
+ assert isinstance(test_ret_np["max"], Number)
+ assert isinstance(test_ret_np["percentile"], list)
+ assert isinstance(test_ret_mt["max"], Number)
+ assert isinstance(test_ret_mt["percentile"], list)
+
+ op.update({"sum": np.sum})
+ test_ret_np = op.evaluate(test_data_np)
+ assert "sum" in test_ret_np
+
+ def test_summary_operations(self):
+ op = SummaryOperations()
+ test_dict = {"min": [0, 1, 2, 3], "max": [2, 3, 4, 5], "mean": [1, 2, 3, 4], "sum": [2, 4, 6, 8]}
+ test_ret = op.evaluate(test_dict)
+ assert isinstance(test_ret["max"], Number)
+ assert isinstance(test_ret["min"], Number)
+
+ op.update({"sum": np.sum})
+ test_ret = op.evaluate(test_dict)
+ assert "sum" in test_ret
+ assert isinstance(test_ret["sum"], Number)
+
+ def test_basic_analyzer_class(self):
+ test_data = {}
+ test_data["image_test"] = np.random.rand(10, 10)
+ report_format = {"stats": None}
+ user_analyzer = TestAnalyzer("image_test", report_format)
+ user_analyzer.update_ops("stats", TestOperations())
+ result = user_analyzer(test_data)
+ assert result["test"]["stats"]["max"] == np.max(test_data["image_test"])
+ assert result["test"]["stats"]["min"] == np.min(test_data["image_test"])
+ assert result["test"]["stats"]["mean"] == np.mean(test_data["image_test"])
+
+ def test_transform_analyzer_class(self):
+ transform = Compose([LoadImaged(keys=["image"]), TestImageAnalyzer(image_key="image")])
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=0, collate_fn=no_collation)
+ for batch_data in self.dataset:
+ d = transform(batch_data[0])
+ assert "test_image" in d
+ assert "test_stats" in d["test_image"]
+ assert "max" in d["test_image"]["test_stats"]
+ assert "min" in d["test_image"]["test_stats"]
+ assert "mean" in d["test_image"]["test_stats"]
+
+ def test_image_stats_case_analyzer(self):
+ analyzer = ImageStats(image_key="image")
+ transform = Compose(
+ [
+ LoadImaged(keys=["image"]),
+ EnsureChannelFirstd(keys=["image"]), # this creates label to be (1,H,W,D)
+ ToDeviced(keys=["image"], device=device, non_blocking=True),
+ Orientationd(keys=["image"], axcodes="RAS"),
+ EnsureTyped(keys=["image"], data_type="tensor"),
+ analyzer,
+ ]
+ )
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ for batch_data in self.dataset:
+ d = transform(batch_data[0])
+ report_format = analyzer.get_report_format()
+ assert verify_report_format(d["image_stats"], report_format)
+
+ def test_foreground_image_stats_cases_analyzer(self):
+ analyzer = FgImageStats(image_key="image", label_key="label")
+ transform_list = [
+ LoadImaged(keys=["image", "label"]),
+ EnsureChannelFirstd(keys=["image", "label"]), # this creates label to be (1,H,W,D)
+ ToDeviced(keys=["image", "label"], device=device, non_blocking=True),
+ Orientationd(keys=["image", "label"], axcodes="RAS"),
+ EnsureTyped(keys=["image", "label"], data_type="tensor"),
+ Lambdad(keys=["label"], func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),
+ SqueezeDimd(keys=["label"], dim=0),
+ analyzer,
+ ]
+ transform = Compose(transform_list)
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ for batch_data in self.dataset:
+ d = transform(batch_data[0])
+ report_format = analyzer.get_report_format()
+ assert verify_report_format(d["image_foreground_stats"], report_format)
+
+ def test_label_stats_case_analyzer(self):
+ analyzer = LabelStats(image_key="image", label_key="label")
+ transform = Compose(
+ [
+ LoadImaged(keys=["image", "label"]),
+ EnsureChannelFirstd(keys=["image", "label"]), # this creates label to be (1,H,W,D)
+ ToDeviced(keys=["image", "label"], device=device, non_blocking=True),
+ Orientationd(keys=["image", "label"], axcodes="RAS"),
+ EnsureTyped(keys=["image", "label"], data_type="tensor"),
+ Lambdad(keys=["label"], func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),
+ SqueezeDimd(keys=["label"], dim=0),
+ analyzer,
+ ]
+ )
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ for batch_data in self.dataset:
+ d = transform(batch_data[0])
+ report_format = analyzer.get_report_format()
+ assert verify_report_format(d["label_stats"], report_format)
+
+ def test_filename_case_analyzer(self):
+ analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
+ analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH)
+ transform_list = [LoadImaged(keys=["image", "label"]), analyzer_image, analyzer_label]
+ transform = Compose(transform_list)
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ for batch_data in self.dataset:
+ d = transform(batch_data[0])
+ assert DataStatsKeys.BY_CASE_IMAGE_PATH in d
+ assert DataStatsKeys.BY_CASE_IMAGE_PATH in d
+
+ def test_filename_case_analyzer_image_only(self):
+ analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
+ analyzer_label = FilenameStats(None, DataStatsKeys.BY_CASE_IMAGE_PATH)
+ transform_list = [LoadImaged(keys=["image"]), analyzer_image, analyzer_label]
+ transform = Compose(transform_list)
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ for batch_data in self.dataset:
+ d = transform(batch_data[0])
+ assert DataStatsKeys.BY_CASE_IMAGE_PATH in d
+ assert d[DataStatsKeys.BY_CASE_IMAGE_PATH] == "None"
+
+ def test_image_stats_summary_analyzer(self):
+ summary_analyzer = ImageStatsSumm("image_stats")
+
+ transform_list = [
+ LoadImaged(keys=["image"]),
+ EnsureChannelFirstd(keys=["image"]), # this creates label to be (1,H,W,D)
+ ToDeviced(keys=["image"], device=device, non_blocking=True),
+ Orientationd(keys=["image"], axcodes="RAS"),
+ EnsureTyped(keys=["image"], data_type="tensor"),
+ ImageStats(image_key="image"),
+ ]
+ transform = Compose(transform_list)
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ stats = []
+ for batch_data in self.dataset:
+ stats.append(transform(batch_data[0]))
+ summary_report = summary_analyzer(stats)
+ report_format = summary_analyzer.get_report_format()
+ assert verify_report_format(summary_report, report_format)
+
+ def test_fg_image_stats_summary_analyzer(self):
+ summary_analyzer = FgImageStatsSumm("image_foreground_stats")
+
+ transform_list = [
+ LoadImaged(keys=["image", "label"]),
+ EnsureChannelFirstd(keys=["image", "label"]), # this creates label to be (1,H,W,D)
+ ToDeviced(keys=["image", "label"], device=device, non_blocking=True),
+ Orientationd(keys=["image", "label"], axcodes="RAS"),
+ EnsureTyped(keys=["image", "label"], data_type="tensor"),
+ Lambdad(keys="label", func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),
+ SqueezeDimd(keys=["label"], dim=0),
+ FgImageStats(image_key="image", label_key="label"),
+ ]
+ transform = Compose(transform_list)
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ stats = []
+ for batch_data in self.dataset:
+ stats.append(transform(batch_data[0]))
+ summary_report = summary_analyzer(stats)
+ report_format = summary_analyzer.get_report_format()
+ assert verify_report_format(summary_report, report_format)
+
+ def test_label_stats_summary_analyzer(self):
+ summary_analyzer = LabelStatsSumm("label_stats")
+
+ transform_list = [
+ LoadImaged(keys=["image", "label"]),
+ EnsureChannelFirstd(keys=["image", "label"]), # this creates label to be (1,H,W,D)
+ ToDeviced(keys=["image", "label"], device=device, non_blocking=True),
+ Orientationd(keys=["image", "label"], axcodes="RAS"),
+ EnsureTyped(keys=["image", "label"], data_type="tensor"),
+ Lambdad(keys="label", func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),
+ SqueezeDimd(keys=["label"], dim=0),
+ LabelStats(image_key="image", label_key="label"),
+ ]
+ transform = Compose(transform_list)
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ stats = []
+ for batch_data in self.dataset:
+ stats.append(transform(batch_data[0]))
+ summary_report = summary_analyzer(stats)
+ report_format = summary_analyzer.get_report_format()
+ assert verify_report_format(summary_report, report_format)
+
+ def test_seg_summarizer(self):
+ summarizer = SegSummarizer("image", "label")
+ keys = ["image", "label"]
+ transform_list = [
+ LoadImaged(keys=keys),
+ EnsureChannelFirstd(keys=keys), # this creates label to be (1,H,W,D)
+ ToDeviced(keys=keys, device=device, non_blocking=True),
+ Orientationd(keys=keys, axcodes="RAS"),
+ EnsureTyped(keys=keys, data_type="tensor"),
+ Lambdad(keys="label", func=lambda x: torch.argmax(x, dim=0, keepdim=True) if x.shape[0] > 1 else x),
+ SqueezeDimd(keys=["label"], dim=0),
+ summarizer,
+ ]
+ transform = Compose(transform_list)
+ create_sim_data(self.dataroot_dir, sim_datalist, (32, 32, 32), rad_max=8, rad_min=1, num_seg_classes=1)
+ files, _ = datafold_read(sim_datalist, self.dataroot_dir, fold=-1)
+ ds = Dataset(data=files)
+ self.dataset = DataLoader(ds, batch_size=1, shuffle=False, num_workers=n_workers, collate_fn=no_collation)
+ stats = []
+ for batch_data in self.dataset:
+ d = transform(batch_data[0])
+ stats.append(d)
+ report = summarizer.summarize(stats)
+ assert str(DataStatsKeys.IMAGE_STATS) in report
+ assert str(DataStatsKeys.FG_IMAGE_STATS) in report
+ assert str(DataStatsKeys.LABEL_STATS) in report
+
+ def tearDown(self) -> None:
+ self.test_dir.cleanup()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_auto3dseg_ensemble.py b/tests/test_auto3dseg_ensemble.py
new file mode 100644
index 00000000000..7b9656f1ac2
--- /dev/null
+++ b/tests/test_auto3dseg_ensemble.py
@@ -0,0 +1,141 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+import unittest
+from typing import Dict, List
+
+import nibabel as nib
+import numpy as np
+
+from monai.apps.auto3dseg import AlgoEnsembleBestByFold, AlgoEnsembleBestN, AlgoEnsembleBuilder, BundleGen, DataAnalyzer
+from monai.bundle.config_parser import ConfigParser
+from monai.data import create_test_image_3d
+from monai.utils import optional_import
+from monai.utils.enums import AlgoEnsembleKeys
+from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick
+
+_, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter")
+
+fake_datalist: Dict[str, List[Dict]] = {
+ "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}],
+ "training": [
+ {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_002.fake.nii.gz", "label": "tr_label_002.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_003.fake.nii.gz", "label": "tr_label_003.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_004.fake.nii.gz", "label": "tr_label_004.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_005.fake.nii.gz", "label": "tr_label_005.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_006.fake.nii.gz", "label": "tr_label_006.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_007.fake.nii.gz", "label": "tr_label_007.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_008.fake.nii.gz", "label": "tr_label_008.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_009.fake.nii.gz", "label": "tr_label_009.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_010.fake.nii.gz", "label": "tr_label_010.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_011.fake.nii.gz", "label": "tr_label_011.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_012.fake.nii.gz", "label": "tr_label_012.fake.nii.gz"},
+ ],
+}
+
+train_param = {
+ "CUDA_VISIBLE_DEVICES": [0],
+ "num_iterations": 8,
+ "num_iterations_per_validation": 4,
+ "num_images_per_batch": 2,
+ "num_epochs": 2,
+ "num_warmup_iterations": 4,
+}
+
+pred_param = {"files_slices": slice(0, 1), "mode": "mean", "sigmoid": True}
+
+
+@skip_if_quick
+@SkipIfBeforePyTorchVersion((1, 9, 1))
+@unittest.skipIf(not has_tb, "no tensorboard summary writer")
+class TestEnsembleBuilder(unittest.TestCase):
+ def setUp(self) -> None:
+ self.test_dir = tempfile.TemporaryDirectory()
+
+ @skip_if_no_cuda
+ def test_ensemble(self) -> None:
+ test_path = self.test_dir.name
+
+ dataroot = os.path.join(test_path, "dataroot")
+ work_dir = os.path.join(test_path, "workdir")
+
+ da_output_yaml = os.path.join(work_dir, "datastats.yaml")
+ data_src_cfg = os.path.join(work_dir, "data_src_cfg.yaml")
+
+ if not os.path.isdir(dataroot):
+ os.makedirs(dataroot)
+
+ if not os.path.isdir(work_dir):
+ os.makedirs(work_dir)
+
+ # Generate a fake dataset
+ for d in fake_datalist["testing"] + fake_datalist["training"]:
+ im, seg = create_test_image_3d(64, 64, 64, rad_max=10, num_seg_classes=1)
+ nib_image = nib.Nifti1Image(im, affine=np.eye(4))
+ image_fpath = os.path.join(dataroot, d["image"])
+ nib.save(nib_image, image_fpath)
+
+ if "label" in d:
+ nib_image = nib.Nifti1Image(seg, affine=np.eye(4))
+ label_fpath = os.path.join(dataroot, d["label"])
+ nib.save(nib_image, label_fpath)
+
+ # write to a json file
+ fake_json_datalist = os.path.join(dataroot, "fake_input.json")
+ ConfigParser.export_config_file(fake_datalist, fake_json_datalist)
+
+ da = DataAnalyzer(fake_json_datalist, dataroot, output_path=da_output_yaml)
+ da.get_all_case_stats()
+
+ data_src = {
+ "name": "fake_data",
+ "task": "segmentation",
+ "modality": "MRI",
+ "datalist": fake_json_datalist,
+ "dataroot": dataroot,
+ "multigpu": False,
+ "class_names": ["label_class"],
+ }
+
+ ConfigParser.export_config_file(data_src, data_src_cfg)
+
+ with skip_if_downloading_fails():
+ bundle_generator = BundleGen(
+ algo_path=work_dir, data_stats_filename=da_output_yaml, data_src_cfg_name=data_src_cfg
+ )
+ bundle_generator.generate(work_dir, num_fold=2)
+ history = bundle_generator.get_history()
+
+ for h in history:
+ self.assertEqual(len(h.keys()), 1, "each record should have one model")
+ for _, algo in h.items():
+ algo.train(train_param)
+
+ builder = AlgoEnsembleBuilder(history, data_src_cfg)
+ builder.set_ensemble_method(AlgoEnsembleBestN(n_best=2))
+ ensemble = builder.get_ensemble()
+ preds = ensemble(pred_param)
+ self.assertTupleEqual(preds[0].shape, (2, 64, 64, 64))
+
+ builder.set_ensemble_method(AlgoEnsembleBestByFold(2))
+ ensemble = builder.get_ensemble()
+ for algo in ensemble.get_algo_ensemble():
+ print(algo[AlgoEnsembleKeys.ID])
+
+ def tearDown(self) -> None:
+ self.test_dir.cleanup()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_auto3dseg_hpo.py b/tests/test_auto3dseg_hpo.py
new file mode 100644
index 00000000000..708828eed4d
--- /dev/null
+++ b/tests/test_auto3dseg_hpo.py
@@ -0,0 +1,222 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import tempfile
+import unittest
+from functools import partial
+from typing import Dict, List
+
+import nibabel as nib
+import numpy as np
+
+from monai.apps.auto3dseg import BundleGen, DataAnalyzer, NNIGen, OptunaGen, import_bundle_algo_history
+from monai.bundle.config_parser import ConfigParser
+from monai.data import create_test_image_3d
+from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda
+
+_, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter")
+optuna, has_optuna = optional_import("optuna")
+
+
+def skip_if_no_optuna(obj):
+ """
+ Skip the unit tests if torch.cuda.is_available is False.
+ """
+ return unittest.skipUnless(has_optuna, "Skipping optuna tests")(obj)
+
+
+fake_datalist: Dict[str, List[Dict]] = {
+ "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}],
+ "training": [
+ {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_002.fake.nii.gz", "label": "tr_label_002.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_003.fake.nii.gz", "label": "tr_label_003.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_004.fake.nii.gz", "label": "tr_label_004.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_005.fake.nii.gz", "label": "tr_label_005.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_006.fake.nii.gz", "label": "tr_label_006.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_007.fake.nii.gz", "label": "tr_label_007.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_008.fake.nii.gz", "label": "tr_label_008.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_009.fake.nii.gz", "label": "tr_label_009.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_010.fake.nii.gz", "label": "tr_label_010.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_011.fake.nii.gz", "label": "tr_label_011.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_012.fake.nii.gz", "label": "tr_label_012.fake.nii.gz"},
+ ],
+}
+
+
+@SkipIfBeforePyTorchVersion((1, 9, 1))
+@unittest.skipIf(not has_tb, "no tensorboard summary writer")
+class TestHPO(unittest.TestCase):
+ def setUp(self) -> None:
+ self.test_dir = tempfile.TemporaryDirectory()
+ test_path = self.test_dir.name
+
+ work_dir = os.path.abspath(os.path.join(test_path, "workdir"))
+ dataroot = os.path.join(work_dir, "dataroot")
+
+ da_output_yaml = os.path.join(work_dir, "datastats.yaml")
+ data_src_cfg = os.path.join(work_dir, "data_src_cfg.yaml")
+
+ if not os.path.isdir(dataroot):
+ os.makedirs(dataroot)
+
+ if not os.path.isdir(work_dir):
+ os.makedirs(work_dir)
+
+ # Generate a fake dataset
+ for d in fake_datalist["testing"] + fake_datalist["training"]:
+ im, seg = create_test_image_3d(64, 64, 64, rad_max=10, num_seg_classes=1)
+ nib_image = nib.Nifti1Image(im, affine=np.eye(4))
+ image_fpath = os.path.join(dataroot, d["image"])
+ nib.save(nib_image, image_fpath)
+
+ if "label" in d:
+ nib_image = nib.Nifti1Image(seg, affine=np.eye(4))
+ label_fpath = os.path.join(dataroot, d["label"])
+ nib.save(nib_image, label_fpath)
+
+ # write to a json file
+ fake_json_datalist = os.path.join(dataroot, "fake_input.json")
+ ConfigParser.export_config_file(fake_datalist, fake_json_datalist)
+
+ da = DataAnalyzer(fake_json_datalist, dataroot, output_path=da_output_yaml)
+ da.get_all_case_stats()
+
+ data_src = {
+ "name": "fake_data",
+ "task": "segmentation",
+ "modality": "MRI",
+ "datalist": fake_json_datalist,
+ "dataroot": dataroot,
+ "multigpu": False,
+ "class_names": ["label_class"],
+ }
+
+ ConfigParser.export_config_file(data_src, data_src_cfg)
+ with skip_if_downloading_fails():
+ bundle_generator = BundleGen(
+ algo_path=work_dir, data_stats_filename=da_output_yaml, data_src_cfg_name=data_src_cfg
+ )
+ bundle_generator.generate(work_dir, num_fold=2)
+
+ self.history = bundle_generator.get_history()
+ self.work_dir = work_dir
+ self.test_path = test_path
+
+ @skip_if_no_cuda
+ def test_run_algo(self) -> None:
+ override_param = {
+ "num_iterations": 8,
+ "num_iterations_per_validation": 4,
+ "num_images_per_batch": 2,
+ "num_epochs": 2,
+ "num_warmup_iterations": 4,
+ }
+
+ algo_dict = self.history[0]
+ algo_name = list(algo_dict.keys())[0]
+ algo = algo_dict[algo_name]
+ nni_gen = NNIGen(algo=algo, params=override_param)
+ obj_filename = nni_gen.get_obj_filename()
+ # this function will be used in HPO via Python Fire
+ NNIGen().run_algo(obj_filename, self.work_dir)
+
+ @skip_if_no_cuda
+ @skip_if_no_optuna
+ def test_run_optuna(self) -> None:
+ override_param = {
+ "num_iterations": 8,
+ "num_iterations_per_validation": 4,
+ "num_images_per_batch": 2,
+ "num_epochs": 2,
+ "num_warmup_iterations": 4,
+ }
+
+ algo_dict = self.history[0]
+ algo_name = list(algo_dict.keys())[0]
+ algo = algo_dict[algo_name]
+
+ class OptunaGenLearningRate(OptunaGen):
+ def get_hyperparameters(self):
+ return {"learning_rate": self.trial.suggest_float("learning_rate", 0.00001, 0.1)}
+
+ optuna_gen = OptunaGenLearningRate(algo=algo, params=override_param)
+ search_space = {"learning_rate": [0.0001, 0.001, 0.01, 0.1]}
+ study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space), direction="maximize")
+ study.optimize(
+ partial(
+ optuna_gen,
+ obj_filename=optuna_gen.get_obj_filename(),
+ output_folder=os.path.join(self.test_path, "optuna_test"),
+ ),
+ n_trials=2,
+ )
+ print(f"Best value: {study.best_value} (params: {study.best_params})\n")
+
+ @skip_if_no_cuda
+ def test_run_algo_after_move_files(self) -> None:
+ override_param = {
+ "num_iterations": 8,
+ "num_iterations_per_validation": 4,
+ "num_images_per_batch": 2,
+ "num_epochs": 2,
+ "num_warmup_iterations": 4,
+ }
+
+ algo_dict = self.history[0]
+ algo_name = list(algo_dict.keys())[0]
+ algo = algo_dict[algo_name]
+ nni_gen = NNIGen(algo=algo, params=override_param)
+ obj_filename = nni_gen.get_obj_filename()
+
+ work_dir_2 = os.path.join(self.test_path, "workdir2")
+ os.makedirs(work_dir_2)
+ algorithm_template = os.path.join(self.work_dir, "algorithm_templates")
+ algorithm_templates_2 = os.path.join(work_dir_2, "algorithm_templates")
+ algo_dir = os.path.dirname(obj_filename)
+ algo_dir_2 = os.path.join(work_dir_2, os.path.basename(algo_dir))
+
+ obj_filename_2 = os.path.join(algo_dir_2, "algo_object.pkl")
+ shutil.copytree(algorithm_template, algorithm_templates_2)
+ shutil.copytree(algo_dir, algo_dir_2)
+ # this function will be used in HPO via Python Fire in remote
+ NNIGen().run_algo(obj_filename_2, work_dir_2, template_path=algorithm_templates_2)
+
+ @skip_if_no_cuda
+ def test_get_history(self) -> None:
+ override_param = {
+ "num_iterations": 8,
+ "num_iterations_per_validation": 4,
+ "num_images_per_batch": 2,
+ "num_epochs": 2,
+ "num_warmup_iterations": 4,
+ }
+
+ algo_dict = self.history[0]
+ algo_name = list(algo_dict.keys())[0]
+ algo = algo_dict[algo_name]
+ nni_gen = NNIGen(algo=algo, params=override_param)
+ obj_filename = nni_gen.get_obj_filename()
+
+ NNIGen().run_algo(obj_filename, self.work_dir)
+
+ history = import_bundle_algo_history(self.work_dir, only_trained=True)
+ assert len(history) == 1
+
+ def tearDown(self) -> None:
+ self.test_dir.cleanup()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_basic_unetplusplus.py b/tests/test_basic_unetplusplus.py
new file mode 100644
index 00000000000..3bca65676a4
--- /dev/null
+++ b/tests/test_basic_unetplusplus.py
@@ -0,0 +1,107 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets import BasicUNetPlusPlus
+from tests.utils import test_script_save
+
+CASES_1D = []
+for mode in ["pixelshuffle", "nontrainable", "deconv", None]:
+ kwargs = {"spatial_dims": 1, "in_channels": 5, "out_channels": 8}
+ if mode is not None:
+ kwargs["upsample"] = mode # type: ignore
+ CASES_1D.append([kwargs, (10, 5, 33), (10, 8, 33)])
+
+CASES_2D = []
+for mode in ["pixelshuffle", "nontrainable", "deconv"]:
+ for d1 in range(33, 64, 14):
+ for d2 in range(63, 33, -21):
+ in_channels, out_channels = 2, 3
+ CASES_2D.append(
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": in_channels,
+ "out_channels": out_channels,
+ "features": (12, 12, 13, 14, 15, 16),
+ "upsample": mode,
+ },
+ (2, in_channels, d1, d2),
+ (2, out_channels, d1, d2),
+ ]
+ )
+CASES_3D = [
+ [ # single channel 3D, batch 2
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 2,
+ "features": (16, 20, 21, 22, 23, 11),
+ "upsample": "pixelshuffle",
+ },
+ (2, 1, 33, 34, 35),
+ (2, 2, 33, 34, 35),
+ ],
+ [ # 2-channel 3D, batch 3
+ {
+ "spatial_dims": 3,
+ "in_channels": 2,
+ "out_channels": 7,
+ "features": (14, 15, 16, 17, 18, 11),
+ "upsample": "deconv",
+ },
+ (3, 2, 33, 37, 34),
+ (3, 7, 33, 37, 34),
+ ],
+ [ # 4-channel 3D, batch 5
+ {
+ "spatial_dims": 3,
+ "in_channels": 4,
+ "out_channels": 2,
+ "features": (14, 15, 16, 17, 18, 10),
+ "upsample": "nontrainable",
+ },
+ (5, 4, 34, 35, 37),
+ (5, 2, 34, 35, 37),
+ ],
+]
+
+
+class TestBasicUNETPlusPlus(unittest.TestCase):
+ @parameterized.expand(CASES_1D + CASES_2D + CASES_3D)
+ def test_shape(self, input_param, input_shape, expected_shape):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(input_param)
+ net = BasicUNetPlusPlus(**input_param).to(device)
+ with eval_mode(net):
+ result = net(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_shape)
+
+ def test_deep_supervision_shape(self):
+ net = BasicUNetPlusPlus(spatial_dims=2, deep_supervision=True, in_channels=3, out_channels=3)
+ test_data = torch.randn(16, 3, 32, 32)
+ with eval_mode(net):
+ result = net(test_data)
+ self.assertEqual(result[0].shape, test_data.shape)
+
+ def test_script(self):
+ net = BasicUNetPlusPlus(spatial_dims=2, deep_supervision=True, in_channels=1, out_channels=3)
+ test_data = torch.randn(16, 1, 32, 32)
+ test_script_save(net, test_data)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py
index a7cbff22f0c..e5847c57ab4 100644
--- a/tests/test_bundle_ckpt_export.py
+++ b/tests/test_bundle_ckpt_export.py
@@ -11,7 +11,6 @@
import json
import os
-import subprocess
import tempfile
import unittest
@@ -20,7 +19,7 @@
from monai.bundle import ConfigParser
from monai.data import load_net_with_metadata
from monai.networks import save_state
-from tests.utils import skip_if_windows
+from tests.utils import command_line_tests, skip_if_windows
TEST_CASE_1 = [""]
@@ -49,7 +48,7 @@ def test_export(self, key_in_ckpt):
cmd = ["coverage", "run", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", ts_file]
cmd += ["--meta_file", meta_file, "--config_file", f"['{config_file}','{def_args_file}']", "--ckpt_file"]
cmd += [ckpt_file, "--key_in_ckpt", key_in_ckpt, "--args_file", def_args_file]
- subprocess.check_call(cmd)
+ command_line_tests(cmd)
self.assertTrue(os.path.exists(ts_file))
_, metadata, extra_files = load_net_with_metadata(
diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py
index 7e609a7b31e..0bb7834dac6 100644
--- a/tests/test_bundle_download.py
+++ b/tests/test_bundle_download.py
@@ -11,7 +11,6 @@
import json
import os
-import subprocess
import tempfile
import unittest
@@ -21,23 +20,29 @@
import monai.networks.nets as nets
from monai.apps import check_hash
from monai.bundle import ConfigParser, load
-from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_quick, skip_if_windows
+from tests.utils import (
+ SkipIfBeforePyTorchVersion,
+ assert_allclose,
+ command_line_tests,
+ skip_if_downloading_fails,
+ skip_if_quick,
+ skip_if_windows,
+)
-TEST_CASE_1 = [
- ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
- "test_bundle",
- "Project-MONAI/MONAI-extra-test-data/0.8.1",
- "a131d39a0af717af32d19e565b434928",
-]
+TEST_CASE_1 = ["test_bundle", None]
+
+TEST_CASE_2 = ["test_bundle_v0.1.1", None]
-TEST_CASE_2 = [
+TEST_CASE_3 = ["test_bundle", "0.1.1"]
+
+TEST_CASE_4 = [
["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
"test_bundle",
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle.zip",
"a131d39a0af717af32d19e565b434928",
]
-TEST_CASE_3 = [
+TEST_CASE_5 = [
["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"],
"test_bundle",
"Project-MONAI/MONAI-extra-test-data/0.8.1",
@@ -45,9 +50,10 @@
"model.pt",
]
-TEST_CASE_4 = [
+TEST_CASE_6 = [
["test_output.pt", "test_input.pt"],
"test_bundle",
+ "0.1.1",
"Project-MONAI/MONAI-extra-test-data/0.8.1",
"cuda" if torch.cuda.is_available() else "cpu",
"model.ts",
@@ -56,22 +62,27 @@
@skip_if_windows
class TestDownload(unittest.TestCase):
- @parameterized.expand([TEST_CASE_1])
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
@skip_if_quick
- def test_download_bundle(self, bundle_files, bundle_name, repo, hash_val):
+ def test_download_bundle(self, bundle_name, version):
+ bundle_files = ["model.pt", "model.ts", "network.json", "test_output.pt", "test_input.pt"]
+ repo = "Project-MONAI/MONAI-extra-test-data/0.8.1"
+ hash_val = "a131d39a0af717af32d19e565b434928"
with skip_if_downloading_fails():
# download a whole bundle from github releases
with tempfile.TemporaryDirectory() as tempdir:
cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--name", bundle_name, "--source", "github"]
cmd += ["--bundle_dir", tempdir, "--repo", repo, "--progress", "False"]
- subprocess.check_call(cmd)
+ if version is not None:
+ cmd += ["--version", version]
+ command_line_tests(cmd)
for file in bundle_files:
- file_path = os.path.join(tempdir, bundle_name, file)
+ file_path = os.path.join(tempdir, "test_bundle", file)
self.assertTrue(os.path.exists(file_path))
if file == "network.json":
self.assertTrue(check_hash(filepath=file_path, val=hash_val))
- @parameterized.expand([TEST_CASE_2])
+ @parameterized.expand([TEST_CASE_4])
@skip_if_quick
def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val):
with skip_if_downloading_fails():
@@ -83,7 +94,7 @@ def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val):
parser.export_config_file(config=def_args, filepath=def_args_file)
cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file]
cmd += ["--url", url]
- subprocess.check_call(cmd)
+ command_line_tests(cmd)
for file in bundle_files:
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))
@@ -92,7 +103,7 @@ def test_url_download_bundle(self, bundle_files, bundle_name, url, hash_val):
class TestLoad(unittest.TestCase):
- @parameterized.expand([TEST_CASE_3])
+ @parameterized.expand([TEST_CASE_5])
@skip_if_quick
def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file):
with skip_if_downloading_fails():
@@ -122,7 +133,7 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[4]), map_location=device)
output = model.forward(input_tensor)
expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[3]), map_location=device)
- torch.testing.assert_allclose(output, expected_output)
+ assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
# load instantiated model directly and test, since the bundle has been downloaded,
# there is no need to input `repo`
@@ -137,18 +148,19 @@ def test_load_weights(self, bundle_files, bundle_name, repo, device, model_file)
)
model_2.eval()
output_2 = model_2.forward(input_tensor)
- torch.testing.assert_allclose(output_2, expected_output)
+ assert_allclose(output_2, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
- @parameterized.expand([TEST_CASE_4])
+ @parameterized.expand([TEST_CASE_6])
@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 7, 1))
- def test_load_ts_module(self, bundle_files, bundle_name, repo, device, model_file):
+ def test_load_ts_module(self, bundle_files, bundle_name, version, repo, device, model_file):
with skip_if_downloading_fails():
# load ts module
with tempfile.TemporaryDirectory() as tempdir:
# load ts module
model_ts, metadata, extra_file_dict = load(
name=bundle_name,
+ version=version,
model_file=model_file,
load_ts_module=True,
bundle_dir=tempdir,
@@ -162,7 +174,7 @@ def test_load_ts_module(self, bundle_files, bundle_name, repo, device, model_fil
input_tensor = torch.load(os.path.join(tempdir, bundle_name, bundle_files[1]), map_location=device)
output = model_ts.forward(input_tensor)
expected_output = torch.load(os.path.join(tempdir, bundle_name, bundle_files[0]), map_location=device)
- torch.testing.assert_allclose(output, expected_output)
+ assert_allclose(output, expected_output, atol=1e-4, rtol=1e-4, type_test=False)
# test metadata
self.assertTrue(metadata["pytorch_version"] == "1.7.1")
# test extra_file_dict
diff --git a/tests/test_bundle_get_data.py b/tests/test_bundle_get_data.py
new file mode 100644
index 00000000000..c36409f724c
--- /dev/null
+++ b/tests/test_bundle_get_data.py
@@ -0,0 +1,67 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from parameterized import parameterized
+
+from monai.bundle import get_all_bundles_list, get_bundle_info, get_bundle_versions
+from monai.utils import optional_import
+from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick, skip_if_windows
+
+requests, _ = optional_import("requests")
+
+TEST_CASE_1 = [{"bundle_name": "brats_mri_segmentation"}]
+
+TEST_CASE_2 = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": None}]
+
+TEST_CASE_FAKE_TOKEN = [{"bundle_name": "spleen_ct_segmentation", "version": "0.1.0", "auth_token": "ghp_errortoken"}]
+
+
+@skip_if_windows
+@SkipIfNoModule("requests")
+class TestGetBundleData(unittest.TestCase):
+ @skip_if_quick
+ def test_get_all_bundles_list(self):
+ with skip_if_downloading_fails():
+ output = get_all_bundles_list()
+ self.assertTrue(isinstance(output, list))
+ self.assertTrue(isinstance(output[0], tuple))
+ self.assertTrue(len(output[0]) == 2)
+
+ @parameterized.expand([TEST_CASE_1])
+ @skip_if_quick
+ def test_get_bundle_versions(self, params):
+ with skip_if_downloading_fails():
+ output = get_bundle_versions(**params)
+ self.assertTrue(isinstance(output, dict))
+ self.assertTrue("latest_version" in output and "all_versions" in output)
+ self.assertTrue("0.1.0" in output["all_versions"])
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+ @skip_if_quick
+ def test_get_bundle_info(self, params):
+ with skip_if_downloading_fails():
+ output = get_bundle_info(**params)
+ self.assertTrue(isinstance(output, dict))
+ for key in ["id", "name", "size", "download_count", "browser_download_url"]:
+ self.assertTrue(key in output)
+
+ @parameterized.expand([TEST_CASE_FAKE_TOKEN])
+ @skip_if_quick
+ def test_fake_token(self, params):
+ with skip_if_downloading_fails():
+ with self.assertRaises(requests.exceptions.HTTPError):
+ get_bundle_info(**params)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_bundle_init_bundle.py b/tests/test_bundle_init_bundle.py
index 24fc425c316..6f88d239ffd 100644
--- a/tests/test_bundle_init_bundle.py
+++ b/tests/test_bundle_init_bundle.py
@@ -10,14 +10,13 @@
# limitations under the License.
import os
-import subprocess
import tempfile
import unittest
import torch
from monai.networks.nets import UNet
-from tests.utils import skip_if_windows
+from tests.utils import command_line_tests, skip_if_windows
@skip_if_windows
@@ -30,7 +29,7 @@ def test_bundle(self):
bundle_root = tempdir + "/test_bundle"
cmd = ["coverage", "run", "-m", "monai.bundle", "init_bundle", bundle_root, tempdir + "/test.pt"]
- subprocess.check_call(cmd)
+ command_line_tests(cmd)
self.assertTrue(os.path.exists(bundle_root + "/configs/metadata.json"))
self.assertTrue(os.path.exists(bundle_root + "/configs/inference.json"))
diff --git a/tests/test_bundle_utils.py b/tests/test_bundle_utils.py
index 46b29651cd5..0fbfae5094f 100644
--- a/tests/test_bundle_utils.py
+++ b/tests/test_bundle_utils.py
@@ -11,7 +11,6 @@
import os
import shutil
-import subprocess
import tempfile
import unittest
@@ -19,7 +18,7 @@
from monai.bundle.utils import load_bundle_config
from monai.networks.nets import UNet
-from tests.utils import skip_if_windows
+from tests.utils import command_line_tests, skip_if_windows
metadata = """
{
@@ -97,6 +96,9 @@ def test_load_config_zip(self):
self.assertEqual(p["test_dict"]["b"], "c")
+ def test_run(self):
+ command_line_tests(["python", "-m", "monai.bundle", "run", "test", "--test", "$print('hello world')"])
+
def test_load_config_ts(self):
# create a Torchscript zip of the bundle
cmd = ["python", "-m", "monai.bundle", "ckpt_export", "network_def", "--filepath", self.ts_file]
@@ -104,7 +106,7 @@ def test_load_config_ts(self):
cmd += ["--config_file", self.test_name]
cmd += ["--ckpt_file", self.modelpt_name]
- subprocess.check_output(cmd, stderr=subprocess.STDOUT)
+ command_line_tests(cmd)
p = load_bundle_config(self.ts_file, "test.json")
diff --git a/tests/test_bundle_verify_metadata.py b/tests/test_bundle_verify_metadata.py
index a8efa4eac9c..773cb40888a 100644
--- a/tests/test_bundle_verify_metadata.py
+++ b/tests/test_bundle_verify_metadata.py
@@ -11,14 +11,13 @@
import json
import os
-import subprocess
import tempfile
import unittest
from parameterized import parameterized
from monai.bundle import ConfigParser, verify_metadata
-from tests.utils import download_url_or_skip_test, skip_if_windows, testing_data_config
+from tests.utils import command_line_tests, download_url_or_skip_test, skip_if_windows, testing_data_config
SCHEMA_FILE = os.path.join(os.path.dirname(__file__), "testing_data", "schema.json")
@@ -45,7 +44,7 @@ def test_verify(self, meta_file, schema_file):
cmd = ["coverage", "run", "-m", "monai.bundle", "verify_metadata", "--meta_file", meta_file]
cmd += ["--filepath", schema_file, "--hash_val", self.config["hash_val"], "--args_file", def_args_file]
- subprocess.check_call(cmd)
+ command_line_tests(cmd)
def test_verify_error(self):
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py
index 05a0731b43d..e2be20a32b3 100644
--- a/tests/test_bundle_verify_net.py
+++ b/tests/test_bundle_verify_net.py
@@ -10,14 +10,13 @@
# limitations under the License.
import os
-import subprocess
import tempfile
import unittest
from parameterized import parameterized
from monai.bundle import ConfigParser
-from tests.utils import skip_if_windows
+from tests.utils import command_line_tests, skip_if_windows
TEST_CASE_1 = [
os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json"),
@@ -36,11 +35,8 @@ def test_verify(self, meta_file, config_file):
cmd = ["coverage", "run", "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file"]
cmd += [meta_file, "--config_file", config_file, "-n", "4", "--any", "16", "--args_file", def_args_file]
- cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[16,'*','2**p*n']"]
-
- test_env = os.environ.copy()
- print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES"))
- subprocess.check_call(cmd, env=test_env)
+ cmd += ["--device", "cpu", "--_meta_#network_data_format#inputs#image#spatial_shape", "[16,'*','2**p*n']"]
+ command_line_tests(cmd)
if __name__ == "__main__":
diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py
index 4fa1b5ea695..86ebced9f31 100644
--- a/tests/test_cachedataset.py
+++ b/tests/test_cachedataset.py
@@ -26,7 +26,6 @@
TEST_CASE_2 = [None, (128, 128, 128)]
-
TEST_DS = []
for c in (0, 1, 2):
for l in (0, 1, 2):
@@ -203,6 +202,7 @@ def test_hash_as_key(self, transform, expected_shape):
self.assertEqual(len(dataset), 5)
# ensure no duplicated cache content
self.assertEqual(len(dataset._cache), 3)
+ self.assertEqual(len(dataset._hash_keys), 3)
self.assertEqual(dataset.cache_num, 3)
data1 = dataset[0]
data2 = dataset[1]
diff --git a/tests/test_cast_to_typed.py b/tests/test_cast_to_typed.py
index 4c7623a9e05..1ac23314a55 100644
--- a/tests/test_cast_to_typed.py
+++ b/tests/test_cast_to_typed.py
@@ -36,7 +36,6 @@
{"img": torch.float64, "seg": torch.int8},
]
-
TESTS_CUPY = [
[
{"keys": "image", "dtype": np.uint8},
diff --git a/tests/test_center_scale_crop.py b/tests/test_center_scale_crop.py
index ab07a44eb55..3fe7a453d31 100644
--- a/tests/test_center_scale_crop.py
+++ b/tests/test_center_scale_crop.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import numpy as np
diff --git a/tests/test_center_scale_cropd.py b/tests/test_center_scale_cropd.py
index 894692530d5..088c1c70e7c 100644
--- a/tests/test_center_scale_cropd.py
+++ b/tests/test_center_scale_cropd.py
@@ -23,7 +23,6 @@
[{"keys": "img", "roi_scale": 0.5}, (3, 3, 3, 3), (3, 2, 2, 2)],
]
-
TEST_VALUES = [
[
{"keys": "img", "roi_scale": [0.4, 0.4]},
diff --git a/tests/test_compose.py b/tests/test_compose.py
index 4d1bcfe01c4..e322e216ad5 100644
--- a/tests/test_compose.py
+++ b/tests/test_compose.py
@@ -169,16 +169,16 @@ def test_data_loader(self):
set_determinism(seed=123)
train_loader = DataLoader(train_ds, num_workers=0)
out_1 = next(iter(train_loader))
- self.assertAlmostEqual(out_1.cpu().item(), 0.84291356)
+ self.assertAlmostEqual(out_1.cpu().item(), 0.0409280)
if sys.platform != "win32": # skip multi-worker tests on win32
train_loader = DataLoader(train_ds, num_workers=1)
out_1 = next(iter(train_loader))
- self.assertAlmostEqual(out_1.cpu().item(), 0.180814653)
+ self.assertAlmostEqual(out_1.cpu().item(), 0.78663897075)
train_loader = DataLoader(train_ds, num_workers=2)
out_1 = next(iter(train_loader))
- self.assertAlmostEqual(out_1.cpu().item(), 0.04293707)
+ self.assertAlmostEqual(out_1.cpu().item(), 0.785907334)
set_determinism(None)
def test_data_loader_2(self):
@@ -191,16 +191,16 @@ def test_data_loader_2(self):
train_loader = DataLoader(train_ds, num_workers=0)
out_2 = next(iter(train_loader))
- self.assertAlmostEqual(out_2.cpu().item(), 0.7858843729)
+ self.assertAlmostEqual(out_2.cpu().item(), 0.98921915918)
if sys.platform != "win32": # skip multi-worker tests on win32
train_loader = DataLoader(train_ds, num_workers=1)
out_2 = next(iter(train_loader))
- self.assertAlmostEqual(out_2.cpu().item(), 0.305763411)
+ self.assertAlmostEqual(out_2.cpu().item(), 0.32985207)
train_loader = DataLoader(train_ds, num_workers=2)
out_1 = next(iter(train_loader))
- self.assertAlmostEqual(out_1.cpu().item(), 0.131966779)
+ self.assertAlmostEqual(out_1.cpu().item(), 0.28602141572)
set_determinism(None)
def test_flatten_and_len(self):
diff --git a/tests/test_compute_confusion_matrix.py b/tests/test_compute_confusion_matrix.py
index 0e38357d12b..56d619a289a 100644
--- a/tests/test_compute_confusion_matrix.py
+++ b/tests/test_compute_confusion_matrix.py
@@ -12,7 +12,6 @@
import unittest
from typing import Any, Dict, List
-import numpy as np
import torch
from parameterized import parameterized
@@ -22,20 +21,24 @@
do_metric_reduction,
get_confusion_matrix,
)
+from tests.utils import assert_allclose
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# input data
data: Dict[Any, Any] = {
"y_pred": torch.tensor(
[
[[[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 0.0]]],
- ]
+ ],
+ device=_device,
),
"y": torch.tensor(
[
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]]],
[[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
- ]
+ ],
+ device=_device,
),
}
@@ -141,12 +144,12 @@
"mk",
]
result: Any = None
-for idx in range(len(metric_names)):
+for idx, item in enumerate(metric_names):
for reduction in ["mean", "mean_batch"]:
TEST_CASE: List[Any] = [data.copy()]
TEST_CASE[0]["compute_sample"] = True
TEST_CASE[0]["include_background"] = True
- TEST_CASE[0]["metric_name"] = metric_names[idx]
+ TEST_CASE[0]["metric_name"] = item
TEST_CASE[0]["reduction"] = reduction
TEST_CASE[0]["get_not_nans"] = True
if reduction == "mean_batch":
@@ -211,10 +214,10 @@ def test_value(self, input_data, expected_value):
# include or ignore background
input_data["include_background"] = True
result = get_confusion_matrix(**input_data)
- np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
+ assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
input_data["include_background"] = False
result = get_confusion_matrix(**input_data)
- np.testing.assert_allclose(result, expected_value[:, 1:, :], atol=1e-4, rtol=1e-4)
+ assert_allclose(result, expected_value[:, 1:, :], atol=1e-4, rtol=1e-4)
@parameterized.expand(TEST_CASES_COMPUTE_SAMPLE)
def test_compute_sample(self, input_data, expected_value):
@@ -225,7 +228,7 @@ def test_compute_sample(self, input_data, expected_value):
metric = ConfusionMatrixMetric(**params)
metric(**vals)
result, _ = metric.aggregate()[0]
- np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
+ assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
@parameterized.expand(TEST_CASES_COMPUTE_SAMPLE_MULTI_METRICS)
def test_compute_sample_multiple_metrics(self, input_data, expected_values):
@@ -236,10 +239,10 @@ def test_compute_sample_multiple_metrics(self, input_data, expected_values):
metric = ConfusionMatrixMetric(**params)
metric(**vals)
results = metric.aggregate()
- for idx in range(len(results)):
- result = results[idx][0]
+ for idx, item in enumerate(results):
+ result = item[0]
expected_value = expected_values[idx]
- np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
+ assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
@parameterized.expand(TEST_CASES_COMPUTE_SAMPLE_NAN)
def test_compute_sample_with_nan(self, input_data, expected_value, expected_not_nans):
@@ -250,8 +253,8 @@ def test_compute_sample_with_nan(self, input_data, expected_value, expected_not_
metric = ConfusionMatrixMetric(**params)
metric(**vals)
result, not_nans = metric.aggregate()[0]
- np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
- np.testing.assert_allclose(not_nans, expected_not_nans, atol=1e-4, rtol=1e-4)
+ assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
+ assert_allclose(not_nans, expected_not_nans, atol=1e-4, rtol=1e-4)
@parameterized.expand([TEST_CASES_CLF])
def test_clf_with_nan(self, input_data, expected_value):
@@ -261,11 +264,11 @@ def test_clf_with_nan(self, input_data, expected_value):
vals["y"] = params.pop("y")
metric = ConfusionMatrixMetric(**params)
result = metric(**vals)
- np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
+ assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
result, _ = metric.aggregate(reduction="mean_channel")[0]
expected_value, _ = do_metric_reduction(expected_value, "mean_channel")
expected_value = compute_confusion_matrix_metric("tpr", expected_value)
- np.testing.assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
+ assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
if __name__ == "__main__":
diff --git a/tests/test_compute_froc.py b/tests/test_compute_froc.py
index d68f3f7fb49..91c1ea1977a 100644
--- a/tests/test_compute_froc.py
+++ b/tests/test_compute_froc.py
@@ -17,11 +17,12 @@
from monai.metrics import compute_fp_tp_probs, compute_froc_curve_data, compute_froc_score
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
TEST_CASE_1 = [
{
- "probs": torch.tensor([1, 0.6, 0.8]),
- "y_coord": torch.tensor([0, 2, 3]),
- "x_coord": torch.tensor([3, 0, 1]),
+ "probs": torch.tensor([1, 0.6, 0.8], device=_device),
+ "y_coord": torch.tensor([0, 2, 3], device=_device),
+ "x_coord": torch.tensor([3, 0, 1], device=_device),
"evaluation_mask": np.array([[0, 0, 1, 1], [2, 2, 0, 0], [0, 3, 3, 0], [0, 3, 3, 3]]),
"labels_to_exclude": [2],
"resolution_level": 0,
diff --git a/tests/test_compute_ho_ver_maps.py b/tests/test_compute_ho_ver_maps.py
new file mode 100644
index 00000000000..5c4674dd048
--- /dev/null
+++ b/tests/test_compute_ho_ver_maps.py
@@ -0,0 +1,73 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.transforms.intensity.array import ComputeHoVerMaps
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+_, has_skimage = optional_import("skimage", "0.19.0", min_version)
+
+INSTANCE_MASK = np.zeros((1, 16, 16), dtype="int16")
+INSTANCE_MASK[:, 5:8, 4:11] = 1
+INSTANCE_MASK[:, 3:5, 6:9] = 1
+INSTANCE_MASK[:, 8:10, 6:9] = 1
+INSTANCE_MASK[:, 13:, 13:] = 2
+H_MAP = torch.zeros((16, 16), dtype=torch.float32)
+H_MAP[5:8, 4] = -1.0
+H_MAP[5:8, 5] = -2.0 / 3.0
+H_MAP[3:10, 6] = -1.0 / 3.0
+H_MAP[3:10, 7] = 0.0
+H_MAP[3:10, 8] = 1.0 / 3.0
+H_MAP[5:8, 9] = 2.0 / 3.0
+H_MAP[5:8, 10] = 1.0
+H_MAP[13:, 13] = -1.0
+H_MAP[13:, 14] = 0.0
+H_MAP[13:, 15] = 1.0
+V_MAP = torch.zeros((16, 16), dtype=torch.float32)
+V_MAP[3, 6:9] = -1.0
+V_MAP[4, 6:9] = -2.0 / 3.0
+V_MAP[5, 4:11] = -1.0 / 3.0
+V_MAP[6, 4:11] = 0.0
+V_MAP[7, 4:11] = 1.0 / 3.0
+V_MAP[8, 6:9] = 2.0 / 3.0
+V_MAP[9, 6:9] = 1.0
+V_MAP[13, 13:] = -1.0
+V_MAP[14, 13:] = 0.0
+V_MAP[15, 13:] = 1.0
+HV_MAPS = torch.stack([H_MAP, V_MAP])
+TEST_CASE_0 = [{}, INSTANCE_MASK, HV_MAPS]
+TEST_CASE_1 = [{"dtype": "float64"}, INSTANCE_MASK, HV_MAPS]
+
+TESTS = []
+for p in TEST_NDARRAYS:
+ TESTS.append([p, *TEST_CASE_0])
+ TESTS.append([p, *TEST_CASE_1])
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+class ComputeHoVerMapsTests(unittest.TestCase):
+ @parameterized.expand(TESTS)
+ def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask):
+ input_image = in_type(mask)
+ result = ComputeHoVerMaps(**arguments)(input_image)
+ self.assertTrue(isinstance(result, torch.Tensor))
+ self.assertTrue(str(result.dtype).split(".")[1] == arguments.get("dtype", "float32"))
+ assert_allclose(result, hv_mask, type_test="tensor")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_compute_ho_ver_maps_d.py b/tests/test_compute_ho_ver_maps_d.py
new file mode 100644
index 00000000000..475e50bc703
--- /dev/null
+++ b/tests/test_compute_ho_ver_maps_d.py
@@ -0,0 +1,77 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.transforms.intensity.dictionary import ComputeHoVerMapsd
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+_, has_skimage = optional_import("skimage", "0.19.0", min_version)
+
+INSTANCE_MASK = np.zeros((1, 16, 16), dtype="int16")
+INSTANCE_MASK[:, 5:8, 4:11] = 1
+INSTANCE_MASK[:, 3:5, 6:9] = 1
+INSTANCE_MASK[:, 8:10, 6:9] = 1
+INSTANCE_MASK[:, 13:, 13:] = 2
+H_MAP = torch.zeros((16, 16), dtype=torch.float32)
+H_MAP[5:8, 4] = -1.0
+H_MAP[5:8, 5] = -2.0 / 3.0
+H_MAP[3:10, 6] = -1.0 / 3.0
+H_MAP[3:10, 7] = 0.0
+H_MAP[3:10, 8] = 1.0 / 3.0
+H_MAP[5:8, 9] = 2.0 / 3.0
+H_MAP[5:8, 10] = 1.0
+H_MAP[13:, 13] = -1.0
+H_MAP[13:, 14] = 0.0
+H_MAP[13:, 15] = 1.0
+V_MAP = torch.zeros((16, 16), dtype=torch.float32)
+V_MAP[3, 6:9] = -1.0
+V_MAP[4, 6:9] = -2.0 / 3.0
+V_MAP[5, 4:11] = -1.0 / 3.0
+V_MAP[6, 4:11] = 0.0
+V_MAP[7, 4:11] = 1.0 / 3.0
+V_MAP[8, 6:9] = 2.0 / 3.0
+V_MAP[9, 6:9] = 1.0
+V_MAP[13, 13:] = -1.0
+V_MAP[14, 13:] = 0.0
+V_MAP[15, 13:] = 1.0
+HV_MAPS = torch.stack([H_MAP, V_MAP])
+TEST_CASE_0 = [{}, {"mask": INSTANCE_MASK}, {"hover_mask": HV_MAPS}]
+TEST_CASE_1 = [{"dtype": "float64"}, {"mask": INSTANCE_MASK}, {"hover_mask": HV_MAPS}]
+TEST_CASE_1 = [{"dtype": "float64", "new_key_prefix": ""}, {"mask": INSTANCE_MASK}, {"mask": HV_MAPS}]
+
+TESTS = []
+for p in TEST_NDARRAYS:
+ TESTS.append([p, *TEST_CASE_0])
+ TESTS.append([p, *TEST_CASE_1])
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+class ComputeHoVerMapsDictTests(unittest.TestCase):
+ @parameterized.expand(TESTS)
+ def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask):
+ hv_key = list(hv_mask.keys())[0]
+ input_image = {}
+ for k in mask.keys():
+ input_image[k] = in_type(mask[k])
+ result = ComputeHoVerMapsd(keys="mask", **arguments)(input_image)[hv_key]
+ self.assertTrue(isinstance(result, torch.Tensor))
+ self.assertTrue(str(result.dtype).split(".")[1] == arguments.get("dtype", "float32"))
+ assert_allclose(result, hv_mask[hv_key], type_test="tensor")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_compute_meandice.py b/tests/test_compute_meandice.py
index c925c6f148b..4dd5d77c4fb 100644
--- a/tests/test_compute_meandice.py
+++ b/tests/test_compute_meandice.py
@@ -15,13 +15,14 @@
import torch
from parameterized import parameterized
-from monai.metrics import DiceMetric, compute_meandice
+from monai.metrics import DiceMetric, compute_dice, compute_meandice
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# keep background
TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1)
{
- "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
- "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
+ "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),
+ "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),
"include_background": True,
},
[[0.8]],
@@ -186,7 +187,7 @@
class TestComputeMeanDice(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_9, TEST_CASE_11, TEST_CASE_12])
def test_value(self, input_data, expected_value):
- result = compute_meandice(**input_data)
+ result = compute_dice(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
@parameterized.expand([TEST_CASE_3])
diff --git a/tests/test_compute_meaniou.py b/tests/test_compute_meaniou.py
index d6ff95dc74e..52a0223a2d4 100644
--- a/tests/test_compute_meaniou.py
+++ b/tests/test_compute_meaniou.py
@@ -15,13 +15,14 @@
import torch
from parameterized import parameterized
-from monai.metrics import MeanIoU, compute_meaniou
+from monai.metrics import MeanIoU, compute_iou, compute_meaniou
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# keep background
TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1)
{
- "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]]),
- "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
+ "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),
+ "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),
"include_background": True,
},
[[0.6667]],
@@ -191,7 +192,7 @@ def test_value(self, input_data, expected_value):
@parameterized.expand([TEST_CASE_3])
def test_nans(self, input_data, expected_value):
- result = compute_meaniou(**input_data)
+ result = compute_iou(**input_data)
self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value))
# MeanIoU class tests
diff --git a/tests/test_compute_panoptic_quality.py b/tests/test_compute_panoptic_quality.py
new file mode 100644
index 00000000000..cf5d0deb2a7
--- /dev/null
+++ b/tests/test_compute_panoptic_quality.py
@@ -0,0 +1,111 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from typing import List
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.metrics import PanopticQualityMetric, compute_panoptic_quality
+from tests.utils import SkipIfNoModule
+
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# TEST_FUNC_CASE related cases are used to test for single image with HW input shape
+
+sample_1 = torch.randint(low=0, high=5, size=(64, 64), device=_device)
+sample_2_pred = torch.as_tensor([[0, 1, 1, 1], [0, 0, 0, 0], [2, 0, 3, 3], [4, 2, 2, 0]], device=_device)
+sample_2_pred_need_remap = torch.as_tensor([[0, 7, 7, 7], [0, 0, 0, 0], [1, 0, 8, 8], [9, 1, 1, 0]], device=_device)
+sample_2_gt = torch.as_tensor([[1, 1, 2, 1], [0, 0, 0, 0], [1, 3, 0, 0], [4, 3, 3, 3]], device=_device)
+# if pred == gt, result should be 1
+TEST_FUNC_CASE_1 = [{"pred": sample_1, "gt": sample_1, "match_iou_threshold": 0.99}, 1.0]
+
+# test sample_2 when match_iou_threshold = 0.5
+TEST_FUNC_CASE_2 = [{"pred": sample_2_pred, "gt": sample_2_gt, "match_iou_threshold": 0.5}, 0.25]
+# test sample_2 when match_iou_threshold = 0.3, metric_name = "sq"
+TEST_FUNC_CASE_3 = [{"pred": sample_2_pred, "gt": sample_2_gt, "metric_name": "sq", "match_iou_threshold": 0.3}, 0.6]
+# test sample_2 when match_iou_threshold = 0.3, pred has different order, metric_name = "RQ"
+TEST_FUNC_CASE_4 = [
+ {"pred": sample_2_pred_need_remap, "gt": sample_2_gt, "metric_name": "RQ", "match_iou_threshold": 0.3},
+ 0.75,
+]
+
+# TEST_CLS_CASE related cases are used to test the PanopticQualityMetric with B2HW input
+sample_3_pred = torch.as_tensor(
+ [
+ [[[2, 0, 1], [2, 1, 1], [0, 1, 1]], [[0, 1, 3], [0, 0, 0], [1, 2, 1]]],
+ [[[1, 1, 1], [3, 2, 0], [3, 2, 1]], [[1, 1, 3], [3, 1, 1], [0, 3, 0]]],
+ ],
+ device=_device,
+)
+
+sample_3_gt = torch.as_tensor(
+ [
+ [[[2, 0, 0], [2, 0, 0], [2, 2, 3]], [[3, 3, 3], [3, 2, 1], [2, 2, 3]]],
+ [[[1, 1, 1], [0, 0, 3], [0, 0, 3]], [[0, 1, 3], [2, 1, 0], [3, 0, 3]]],
+ ],
+ device=_device,
+)
+
+# test sample_3, num_classes = 3, match_iou_threshold = 0.5
+TEST_CLS_CASE_1 = [{"num_classes": 3, "match_iou_threshold": 0.5}, sample_3_pred, sample_3_gt, (0.0, 0.0, 0.25)]
+
+# test sample_3, num_classes = 3, match_iou_threshold = 0.3
+TEST_CLS_CASE_2 = [{"num_classes": 3, "match_iou_threshold": 0.3}, sample_3_pred, sample_3_gt, (0.25, 0.5, 0.25)]
+
+# test sample_3, num_classes = 4, match_iou_threshold = 0.3, metric_name = "segmentation_quality"
+TEST_CLS_CASE_3 = [
+ {"num_classes": 4, "match_iou_threshold": 0.3, "metric_name": "segmentation_quality"},
+ sample_3_pred,
+ sample_3_gt,
+ (0.5, 0.5, 1.0, 0.0),
+]
+
+# test sample_3, num_classes = 3, match_iou_threshold = 0.4, reduction = "none", metric_name = "Recognition Quality"
+TEST_CLS_CASE_4 = [
+ {"num_classes": 3, "reduction": "none", "match_iou_threshold": 0.4, "metric_name": "Recognition Quality"},
+ sample_3_pred,
+ sample_3_gt,
+ [[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]],
+]
+
+# test sample_3, num_classes = 3, match_iou_threshold = 0.4, reduction = "none", multiple metrics
+TEST_CLS_CASE_5 = [
+ {"num_classes": 3, "reduction": "none", "match_iou_threshold": 0.4, "metric_name": ["Recognition Quality", "pq"]},
+ sample_3_pred,
+ sample_3_gt,
+ [torch.as_tensor([[0.0, 1.0, 0.0], [0.6667, 0.0, 0.4]]), torch.as_tensor([[0.0, 0.5, 0.0], [0.3333, 0.0, 0.4]])],
+]
+
+
+@SkipIfNoModule("scipy.optimize")
+class TestPanopticQualityMetric(unittest.TestCase):
+ @parameterized.expand([TEST_FUNC_CASE_1, TEST_FUNC_CASE_2, TEST_FUNC_CASE_3, TEST_FUNC_CASE_4])
+ def test_value(self, input_params, expected_value):
+ result = compute_panoptic_quality(**input_params)
+ np.testing.assert_allclose(result.cpu().detach().item(), expected_value, atol=1e-4)
+
+ @parameterized.expand([TEST_CLS_CASE_1, TEST_CLS_CASE_2, TEST_CLS_CASE_3, TEST_CLS_CASE_4, TEST_CLS_CASE_5])
+ def test_value_class(self, input_params, y_pred, y_gt, expected_value):
+ metric = PanopticQualityMetric(**input_params)
+ metric(y_pred, y_gt)
+ outputs = metric.aggregate()
+ if isinstance(outputs, List):
+ for output, value in zip(outputs, expected_value):
+ np.testing.assert_allclose(output.cpu().numpy(), np.asarray(value), atol=1e-4)
+ else:
+ np.testing.assert_allclose(outputs.cpu().numpy(), np.asarray(expected_value), atol=1e-4)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py
index 2c9135024fa..0e57f1fe4a6 100644
--- a/tests/test_compute_roc_auc.py
+++ b/tests/test_compute_roc_auc.py
@@ -19,9 +19,10 @@
from monai.metrics import ROCAUCMetric, compute_roc_auc
from monai.transforms import Activations, AsDiscrete, Compose, ToTensor
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
TEST_CASE_1 = [
- torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]]),
- torch.tensor([[0], [1], [0], [1]]),
+ torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]], device=_device),
+ torch.tensor([[0], [1], [0], [1]], device=_device),
True,
2,
"macro",
diff --git a/tests/test_compute_variance.py b/tests/test_compute_variance.py
new file mode 100644
index 00000000000..2743fcdc798
--- /dev/null
+++ b/tests/test_compute_variance.py
@@ -0,0 +1,143 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.metrics import VarianceMetric, compute_variance
+
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# keep background, 1D Case
+TEST_CASE_1 = [ # y_pred (3, 1, 3), expected out (0.0)
+ {
+ "y_pred": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device),
+ "include_background": True,
+ "spatial_map": False,
+ },
+ [[0.0]],
+]
+
+# keep background, 2D Case
+TEST_CASE_2 = [ # y_pred (1, 1, 2, 2), expected out (0.0)
+ {
+ "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
+ "include_background": True,
+ "spatial_map": False,
+ },
+ [[0.0]],
+]
+
+# keep background, 3D Case
+TEST_CASE_3 = [ # y_pred (1, 1, 1, 2, 2), expected out (0.0)
+ {
+ "y_pred": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
+ "include_background": True,
+ "spatial_map": False,
+ },
+ [[0.0]],
+]
+
+# remove background, 1D Case
+TEST_CASE_4 = [ # y_pred (3, 1, 3), expected out (0.0)
+ {
+ "y_pred": torch.tensor(
+ [
+ [[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]],
+ [[4.0, 5.0, 6.0], [1.0, 1.0, 1.0]],
+ [[7.0, 8.0, 9.0], [1.0, 1.0, 1.0]],
+ ],
+ device=_device,
+ ),
+ "include_background": False,
+ "spatial_map": False,
+ },
+ [[0.0]],
+]
+
+# Spatial Map Test Case for 2D Case
+TEST_CASE_5 = [ # y_pred (1, 1, 2, 2), expected out all (0.0) map of 2x2
+ {
+ "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
+ "include_background": True,
+ "spatial_map": True,
+ },
+ [[0.0, 0.0], [0.0, 0.0]],
+]
+
+# Spatial Map Test Case for 3D Case
+TEST_CASE_6 = [ # y_pred (1, 1, 2, 2, 2), expected out all (0.0) map of 2x2x2
+ {
+ "y_pred": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
+ "include_background": True,
+ "spatial_map": True,
+ },
+ [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
+]
+
+# Threshold test for a 1D Case
+TEST_CASE_7 = [ # y_pred (3, 1, 3), expected out (0.0)
+ {
+ "y_pred": torch.tensor(
+ [
+ [[1.0, 2.0, 3.0], [1.0, 1.0, 0.0]],
+ [[4.0, 5.0, 6.0], [1.0, 1.0, 1.0]],
+ [[7.0, 8.0, 9.0], [1.0, 1.0, 0.0]],
+ [[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]],
+ ],
+ device=_device,
+ ),
+ "include_background": False,
+ "spatial_map": False,
+ "threshold": 0.001,
+ },
+ [[0.083167]],
+]
+
+
+class TestComputeVariance(unittest.TestCase):
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+ def test_value(self, input_data, expected_value):
+ result = compute_variance(**input_data)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+ @parameterized.expand([TEST_CASE_5, TEST_CASE_6])
+ def test_spatial_case(self, input_data, expected_value):
+ result = compute_variance(**input_data)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+ @parameterized.expand([TEST_CASE_7])
+ def test_threshold_case(self, input_data, expected_value):
+ result = compute_variance(**input_data)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+ def test_value_class(self, input_data, expected_value):
+ vals = {}
+ vals["y_pred"] = input_data.pop("y_pred")
+ comp_var = VarianceMetric(**input_data)
+ result = comp_var(**vals)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+ @parameterized.expand([TEST_CASE_5, TEST_CASE_6])
+ def test_spatial_case_class(self, input_data, expected_value):
+ vals = {}
+ vals["y_pred"] = input_data.pop("y_pred")
+ comp_var = VarianceMetric(**input_data)
+ result = comp_var(**vals)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py
index 068abf81c1e..c0a058a0ddd 100644
--- a/tests/test_concat_itemsd.py
+++ b/tests/test_concat_itemsd.py
@@ -16,6 +16,7 @@
from monai.data import MetaTensor
from monai.transforms import ConcatItemsd
+from tests.utils import assert_allclose
class TestConcatItemsd(unittest.TestCase):
@@ -28,8 +29,8 @@ def test_tensor_values(self):
result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
self.assertTrue("cat_img" in result)
result["cat_img"] += 1
- torch.testing.assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
- torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))
+ assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
+ assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))
def test_metatensor_values(self):
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
@@ -42,8 +43,8 @@ def test_metatensor_values(self):
self.assertTrue(isinstance(result["cat_img"], MetaTensor))
self.assertEqual(result["img1"].meta, result["cat_img"].meta)
result["cat_img"] += 1
- torch.testing.assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
- torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))
+ assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
+ assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))
def test_numpy_values(self):
input_data = {"img1": np.array([[0, 1], [1, 2]]), "img2": np.array([[0, 1], [1, 2]])}
@@ -64,15 +65,15 @@ def test_single_tensor(self):
input_data = {"img": torch.tensor([[0, 1], [1, 2]])}
result = ConcatItemsd(keys="img", name="cat_img")(input_data)
result["cat_img"] += 1
- torch.testing.assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]]))
- torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3]]))
+ assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]]))
+ assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3]]))
def test_single_metatensor(self):
input_data = {"img": MetaTensor([[0, 1], [1, 2]])}
result = ConcatItemsd(keys="img", name="cat_img")(input_data)
result["cat_img"] += 1
- torch.testing.assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]]))
- torch.testing.assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3]]))
+ assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]]))
+ assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3]]))
if __name__ == "__main__":
diff --git a/tests/test_config_item.py b/tests/test_config_item.py
index 4d9df0a870e..817175e1e37 100644
--- a/tests/test_config_item.py
+++ b/tests/test_config_item.py
@@ -26,7 +26,7 @@
TEST_CASE_1 = [{"lr": 0.001}, 0.0001]
-TEST_CASE_2 = [{"_target_": "LoadImaged", "keys": ["image"]}, LoadImaged]
+TEST_CASE_2 = [{"_target_": "LoadImaged", "keys": ["image"], "_desc_": "an image reader for 'image'"}, LoadImaged]
# test full module path
TEST_CASE_3 = [{"_target_": "monai.transforms.LoadImaged", "keys": ["image"]}, LoadImaged]
# test `_disabled_`
diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py
index f991b3f5f56..d02a05c914f 100644
--- a/tests/test_config_parser.py
+++ b/tests/test_config_parser.py
@@ -22,9 +22,18 @@
from monai.data import DataLoader, Dataset
from monai.transforms import Compose, LoadImaged, RandTorchVisiond
from monai.utils import min_version, optional_import
+from tests.utils import TimedCall
_, has_tv = optional_import("torchvision", "0.8.0", min_version)
+
+@TimedCall(seconds=100, force_quit=True)
+def case_pdb(sarg=None):
+ config = {"transform": {"_target_": "Compose", "transforms": [], "_debug_": True}}
+ parser = ConfigParser(config=config)
+ parser.get_parsed_content()
+
+
# test the resolved and parsed instances
TEST_CASE_1 = [
{
@@ -78,7 +87,6 @@ def __call__(self, a, b):
}
]
-
TEST_CASE_3 = [
{
"A": 1,
@@ -123,9 +131,15 @@ def test_parse(self, config, expected_ids, output_types):
self.assertEqual(trans, parser.get_parsed_content(id="transform#transforms#0"))
self.assertEqual(trans, parser.get_parsed_content(id="transform#transforms#0", lazy=True))
self.assertNotEqual(trans, parser.get_parsed_content(id="transform#transforms#0", lazy=False))
- # test nested id
+ # test new nested id
+ parser.set("fake_key", "transform#other_transforms#keys", True)
+ self.assertEqual(parser.get(id="transform#other_transforms#keys"), "fake_key")
+ # remove temp fake data
+ parser["transform"].pop("other_transforms")
+ # test update nested id
parser["transform#transforms#0#keys"] = "label2"
self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label2")
+
for id, cls in zip(expected_ids, output_types):
self.assertTrue(isinstance(parser.get_parsed_content(id), cls))
# test root content
@@ -204,17 +218,23 @@ def test_contains(self):
empty_parser = ConfigParser({})
empty_parser.parse()
- parser = ConfigParser({"value": 1, "entry": "string content"})
+ parser = ConfigParser({"value": 1, "entry": "string content", "array": [1, 2]})
parser.parse()
with self.subTest("Testing empty parser"):
self.assertFalse("something" in empty_parser)
+ with self.assertRaises(KeyError):
+ empty_parser["something"]
+ empty_parser["osmething"] = "test"
+ with self.assertRaises(KeyError):
+ empty_parser["something"]
with self.subTest("Testing with keys"):
self.assertTrue("value" in parser)
self.assertFalse("value1" in parser)
self.assertTrue("entry" in parser)
self.assertFalse("entr" in parser)
+ self.assertFalse("array#2" in parser)
def test_lambda_reference(self):
configs = {
@@ -232,6 +252,10 @@ def test_error_instance(self):
with self.assertRaises(RuntimeError):
parser.get_parsed_content("transform", instantiate=True, eval_expr=True)
+ def test_pdb(self):
+ with self.assertRaisesRegex(RuntimeError, ".*bdb.BdbQuit.*"):
+ case_pdb()
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_contrastive_loss.py b/tests/test_contrastive_loss.py
index 5dce8604863..d0eb7d86f26 100644
--- a/tests/test_contrastive_loss.py
+++ b/tests/test_contrastive_loss.py
@@ -19,12 +19,12 @@
TEST_CASES = [
[ # shape: (1, 4), (1, 4)
- {"temperature": 0.5, "batch_size": 1},
+ {"temperature": 0.5},
{"input": torch.tensor([[1.0, 1.0, 0.0, 0.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])},
0.0,
],
[ # shape: (2, 4), (2, 4)
- {"temperature": 0.5, "batch_size": 2},
+ {"temperature": 0.5},
{
"input": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
"target": torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
@@ -32,7 +32,7 @@
1.0986,
],
[ # shape: (1, 4), (1, 4)
- {"temperature": 0.5, "batch_size": 2},
+ {"temperature": 0.5},
{
"input": torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 1.0, 0.0, 0.0]]),
"target": torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]),
@@ -40,12 +40,12 @@
0.8719,
],
[ # shape: (1, 4), (1, 4)
- {"temperature": 0.5, "batch_size": 1},
+ {"temperature": 0.5},
{"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])},
0.0,
],
[ # shape: (1, 4), (1, 4)
- {"temperature": 0.05, "batch_size": 1},
+ {"temperature": 0.05},
{"input": torch.tensor([[0.0, 0.0, 1.0, 1.0]]), "target": torch.tensor([[1.0, 1.0, 0.0, 0.0]])},
0.0,
],
@@ -60,12 +60,12 @@ def test_result(self, input_param, input_data, expected_val):
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
def test_ill_shape(self):
- loss = ContrastiveLoss(temperature=0.5, batch_size=1)
+ loss = ContrastiveLoss(temperature=0.5)
with self.assertRaisesRegex(ValueError, ""):
loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
def test_with_cuda(self):
- loss = ContrastiveLoss(temperature=0.5, batch_size=1)
+ loss = ContrastiveLoss(temperature=0.5)
i = torch.ones((1, 10))
j = torch.ones((1, 10))
if torch.cuda.is_available():
@@ -74,6 +74,10 @@ def test_with_cuda(self):
output = loss(i, j)
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)
+ def check_warning_rasied(self):
+ with self.assertWarns(Warning):
+ ContrastiveLoss(temperature=0.5, batch_size=1)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py
index ab4bd3e3e6e..d411f3f9720 100644
--- a/tests/test_convert_data_type.py
+++ b/tests/test_convert_data_type.py
@@ -39,6 +39,8 @@
)
)
+UNSUPPORTED_TYPES = {np.dtype("uint16"): torch.int32, np.dtype("uint32"): torch.int64, np.dtype("uint64"): torch.int64}
+
class TestTensor(torch.Tensor):
pass
@@ -61,6 +63,13 @@ def test_convert_data_type(self, in_image, im_out):
def test_neg_stride(self):
_ = convert_data_type(np.array((1, 2))[::-1], torch.Tensor)
+ @parameterized.expand(list(UNSUPPORTED_TYPES.items()))
+ def test_unsupported_np_types(self, np_type, pt_type):
+ in_image = np.ones(13, dtype=np_type) # choose a prime size so as to be indivisible by the size of any dtype
+ converted_im, orig_type, orig_device = convert_data_type(in_image, torch.Tensor)
+
+ self.assertEqual(converted_im.dtype, pt_type)
+
@parameterized.expand(TESTS_LIST)
def test_convert_list(self, in_image, im_out, wrap):
output_type = type(im_out) if wrap else type(im_out[0])
diff --git a/tests/test_copy_itemsd.py b/tests/test_copy_itemsd.py
index 11f920cf6b2..8354f45bb5c 100644
--- a/tests/test_copy_itemsd.py
+++ b/tests/test_copy_itemsd.py
@@ -18,6 +18,7 @@
from monai.networks import eval_mode
from monai.transforms import CopyItemsd
from monai.utils import ensure_tuple
+from tests.utils import assert_allclose
TEST_CASE_1 = ["img", 1, "img_1"]
@@ -55,8 +56,8 @@ def test_tensor_values(self):
result = CopyItemsd(keys="img", names="img_1")(input_data)
self.assertTrue("img_1" in result)
result["img_1"] += 1
- torch.testing.assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]], device=device))
- torch.testing.assert_allclose(result["img_1"], torch.tensor([[1, 2], [2, 3]], device=device))
+ assert_allclose(result["img"], torch.tensor([[0, 1], [1, 2]], device=device))
+ assert_allclose(result["img_1"], torch.tensor([[1, 2], [2, 3]], device=device))
def test_array_values(self):
input_data = {"img": [[0, 1], [1, 2]], "seg": [[0, 1], [1, 2]]}
@@ -75,8 +76,8 @@ def test_graph_tensor_values(self):
result = CopyItemsd(keys="pred", times=1, names="pred_1")(input_data)
self.assertTrue("pred_1" in result)
result["pred_1"] += 1.0
- torch.testing.assert_allclose(result["pred"], torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device))
- torch.testing.assert_allclose(result["pred_1"], torch.tensor([[1.0, 2.0], [2.0, 3.0]], device=device))
+ assert_allclose(result["pred"], torch.tensor([[0.0, 1.0], [1.0, 2.0]], device=device))
+ assert_allclose(result["pred_1"], torch.tensor([[1.0, 2.0], [2.0, 3.0]], device=device))
if __name__ == "__main__":
diff --git a/tests/test_cucim_dict_transform.py b/tests/test_cucim_dict_transform.py
index f8b54c31478..4a6d2f9d51c 100644
--- a/tests/test_cucim_dict_transform.py
+++ b/tests/test_cucim_dict_transform.py
@@ -41,7 +41,6 @@
np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32),
]
-
TEST_CASE_ROTATE_1 = [
{"name": "image_rotate_90", "k": 1, "spatial_axis": (-2, -1)},
np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),
diff --git a/tests/test_cucim_transform.py b/tests/test_cucim_transform.py
index 2bf9791bce1..dd73ad94c06 100644
--- a/tests/test_cucim_transform.py
+++ b/tests/test_cucim_transform.py
@@ -41,7 +41,6 @@
np.array([[[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]], [[1.0, 0.0], [3.0, 2.0]]], dtype=np.float32),
]
-
TEST_CASE_ROTATE_1 = [
{"name": "image_rotate_90", "k": 1, "spatial_axis": (-2, -1)},
np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),
diff --git a/tests/test_cumulative_average.py b/tests/test_cumulative_average.py
index 4e7e4ff5d99..543433a6d3a 100644
--- a/tests/test_cumulative_average.py
+++ b/tests/test_cumulative_average.py
@@ -17,46 +17,55 @@
from monai.metrics import CumulativeAverage
-# single class value
-TEST_CASE_1 = [[torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[0.3]])], torch.as_tensor([0.2])]
-
-# multi-class value
-TEST_CASE_2 = [
- [torch.as_tensor([[0.1, 0.2]]), torch.as_tensor([[0.2, 0.3]]), torch.as_tensor([[0.3, 0.4]])],
- torch.as_tensor([0.2, 0.3]),
-]
-
-# Nan value
-TEST_CASE_3 = [
- [torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[float("nan")]])],
- torch.as_tensor([0.15]),
-]
-
-# different input shape
-TEST_CASE_4 = [[torch.as_tensor(0.1), torch.as_tensor(0.2), torch.as_tensor(0.3)], torch.as_tensor(0.2)]
-
-
-class TestCumulativeAverage(unittest.TestCase):
- @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
- def test_value(self, input_data, expected_value):
- average = CumulativeAverage()
- func = average.append if input_data[0].ndim < 2 else average.extend
- func(input_data[0])
- func(input_data[1])
- result = average.aggregate()
- # continue to update new data
- func(input_data[2])
- result = average.aggregate()
- torch.testing.assert_allclose(result, expected_value)
-
- def test_numpy_array(self):
- class TestCumulativeAverage(CumulativeAverage):
- def get_buffer(self):
- return np.array([[1, 2], [3, np.nan]])
-
- average = TestCumulativeAverage()
- result = average.aggregate()
- np.testing.assert_allclose(result, np.array([2.0, 2.0]))
+TEST_CASE_1 = []
+TEST_CASE_1.append([{"vals": [1, 2, 3], "avg": 2}])
+TEST_CASE_1.append([{"vals": [[1, 1, 1], [2, 2, 2], [3, 6, 9]], "avg": [2, 3, 4]}])
+
+TEST_CASE_1.append([{"vals": [2, 4, 6], "counts": [2, 1, 2], "avg": 4}])
+TEST_CASE_1.append(
+ [{"vals": [[3, 2, 1], [2, 3, 2], [0, 0, 9]], "counts": [[4, 4, 4], [4, 4, 4], [2, 2, 2]], "avg": [2, 2, 3]}]
+)
+
+TEST_CASE_1.append([{"vals": [1, 2, float("nan")], "avg": 1.5}])
+
+
+class TestAverageMeter(unittest.TestCase):
+ @parameterized.expand(TEST_CASE_1)
+ def test_value_all(self, data):
+
+ # test orig
+ self.run_test(data)
+
+ # test in numpy
+ data["vals"] = np.array(data["vals"])
+ data["avg"] = np.array(data["avg"])
+ self.run_test(data)
+
+ # test as Tensors
+ data["vals"] = torch.tensor(data["vals"])
+ data["avg"] = torch.tensor(data["avg"], dtype=torch.float)
+ self.run_test(data)
+
+ if torch.cuda.is_available():
+ data["vals"] = data["vals"].cuda()
+ self.run_test(data)
+
+ def run_test(self, data):
+ vals = data["vals"]
+ avg = data["avg"]
+
+ counts = data.get("counts", None)
+ if counts is not None and not isinstance(counts, list) and isinstance(vals, list):
+ counts = [counts] * len(vals)
+
+ avg_meter = CumulativeAverage()
+ for i in range(len(vals)):
+ if counts is not None:
+ avg_meter.append(vals[i], counts[i])
+ else:
+ avg_meter.append(vals[i])
+
+ np.testing.assert_equal(avg_meter.aggregate(), avg)
if __name__ == "__main__":
diff --git a/tests/test_cumulative_average_dist.py b/tests/test_cumulative_average_dist.py
index 5de139e9ac6..a5ee2fed15e 100644
--- a/tests/test_cumulative_average_dist.py
+++ b/tests/test_cumulative_average_dist.py
@@ -11,36 +11,36 @@
import unittest
+import numpy as np
import torch
import torch.distributed as dist
from monai.metrics import CumulativeAverage
-from tests.utils import DistCall, DistTestCase
+from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion
+@SkipIfBeforePyTorchVersion((1, 8))
class DistributedCumulativeAverage(DistTestCase):
@DistCall(nnodes=1, nproc_per_node=2)
def test_value(self):
+
rank = dist.get_rank()
- input_data = [
- [torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[0.3]])],
- [torch.as_tensor([[0.1]]), torch.as_tensor([[0.2]]), torch.as_tensor([[float("nan")]])],
- [torch.as_tensor([[0.1, 0.2]]), torch.as_tensor([[0.2, 0.3]]), torch.as_tensor([[0.3, 0.4]])],
- [torch.as_tensor(0.1), torch.as_tensor(0.2), torch.as_tensor(0.3)],
- ]
- expected = [torch.as_tensor([0.2]), torch.as_tensor([0.15]), torch.as_tensor([0.2, 0.3]), torch.as_tensor(0.2)]
- average = CumulativeAverage()
-
- for i, e in zip(input_data, expected):
- func = average.append if i[0].ndim < 2 else average.extend
- if rank == 0:
- func(i[0])
- func(i[1])
- else:
- func(i[2])
- result = average.aggregate()
- torch.testing.assert_allclose(result, e)
- average.reset()
+ nprocs = dist.get_world_size()
+ is_cuda = dist.get_backend() == dist.Backend.NCCL
+ if is_cuda:
+ torch.cuda.set_device(rank)
+
+ device = torch.device(rank) if is_cuda else torch.device("cpu")
+
+ avg_meter = CumulativeAverage() # each process rank has it's own AverageMeter
+ n_iter = 10
+ for i in range(n_iter):
+ val = torch.as_tensor(rank + i, device=device)
+ avg_meter.append(val=val)
+
+ avg_val = avg_meter.aggregate() # average across all processes
+ expected_val = sum(sum(list(range(rank_i, rank_i + n_iter))) for rank_i in range(nprocs)) / (n_iter * nprocs)
+ np.testing.assert_equal(avg_val, expected_val)
if __name__ == "__main__":
diff --git a/tests/test_cv2_dist.py b/tests/test_cv2_dist.py
new file mode 100644
index 00000000000..cf4c77cfe43
--- /dev/null
+++ b/tests/test_cv2_dist.py
@@ -0,0 +1,46 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+import torch.distributed as dist
+from torch.cuda.amp import autocast
+
+# FIXME: test for the workaround of https://github.com/Project-MONAI/MONAI/issues/5291
+from monai.config.deviceconfig import print_config
+from tests.utils import skip_if_no_cuda
+
+
+def main_worker(rank, ngpus_per_node):
+ dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:12345", world_size=ngpus_per_node, rank=rank)
+ # `benchmark = True` is not compatible with openCV in PyTorch 22.09 docker for multi-gpu training
+ torch.backends.cudnn.benchmark = True
+
+ model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(rank)
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[rank], output_device=rank, find_unused_parameters=False
+ )
+ x = torch.ones(1, 1, 12, 12, 12).to(rank)
+ with autocast(enabled=True):
+ model(x)
+
+
+@skip_if_no_cuda
+class TestCV2Dist(unittest.TestCase):
+ def test_cv2_cuda_ops(self):
+ print_config()
+ ngpus_per_node = torch.cuda.device_count()
+ torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 79126b2dbb5..b75c2a4ed82 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -16,9 +16,10 @@
import torch
from parameterized import parameterized
-from monai.data import CacheDataset, DataLoader, Dataset
+from monai.data import CacheDataset, DataLoader, Dataset, ZipDataset
from monai.transforms import Compose, DataStatsd, Randomizable, SimulateDelayd
-from monai.utils import set_determinism
+from monai.utils import convert_to_numpy, set_determinism
+from tests.utils import assert_allclose
TEST_CASE_1 = [[{"image": np.asarray([1, 2, 3])}, {"image": np.asarray([4, 5])}]]
@@ -74,14 +75,26 @@ def setUp(self):
def tearDown(self):
set_determinism(None)
- def test_randomize(self):
+ @parameterized.expand([[1], [0]])
+ def test_randomize(self, workers):
+ set_determinism(0)
dataset = _RandomDataset()
- dataloader = DataLoader(dataset, batch_size=2, num_workers=3)
+ dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=workers)
output = []
- for _ in range(2):
+ for _ in range(1): # need persistent workers for reproducibility of num_workers 0, 1
for batch in dataloader:
output.extend(batch.data.numpy().flatten().tolist())
- self.assertListEqual(output, [594, 170, 524, 778, 370, 906, 292, 589, 762, 763, 156, 886, 42, 405, 221, 166])
+ set_determinism(None)
+ self.assertListEqual(output, [594, 170, 292, 589, 153, 811, 21, 550])
+
+ def test_zipdataset(self):
+ dataset = ZipDataset([_RandomDataset(), ZipDataset([_RandomDataset(), _RandomDataset()])])
+ dataloader = DataLoader(dataset, batch_size=2, num_workers=2)
+ output = []
+ for _ in range(2):
+ for batch in dataloader:
+ output.extend([convert_to_numpy(batch, wrap_sequence=False)])
+ assert_allclose(np.stack(output).flatten()[:7], np.array([594, 170, 594, 170, 594, 170, 524]))
if __name__ == "__main__":
diff --git a/tests/test_decollate.py b/tests/test_decollate.py
index a634471be55..538eb38311a 100644
--- a/tests/test_decollate.py
+++ b/tests/test_decollate.py
@@ -55,7 +55,6 @@
TESTS_LIST.append((RandRotate90(prob=0.0, max_k=1),))
TESTS_LIST.append((RandAffine(prob=0.0, translate_range=10),))
-
TEST_BASIC = [
[("channel", "channel"), ["channel", "channel"]],
[torch.Tensor([1, 2, 3]), [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]],
diff --git a/tests/test_deepedit_interaction.py b/tests/test_deepedit_interaction.py
index 6bb723268f2..2bcace59fd3 100644
--- a/tests/test_deepedit_interaction.py
+++ b/tests/test_deepedit_interaction.py
@@ -23,7 +23,7 @@
FindDiscrepancyRegionsDeepEditd,
SplitPredsLabeld,
)
-from monai.data import Dataset
+from monai.data import DataLoader, Dataset
from monai.engines import SupervisedTrainer
from monai.engines.utils import IterationEvents
from monai.losses import DiceCELoss
@@ -62,7 +62,7 @@ def run_interaction(self, train):
]
)
dataset = Dataset(data, transform=pre_transforms)
- data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)
+ data_loader = DataLoader(dataset, batch_size=5)
iteration_transforms = [
FindDiscrepancyRegionsDeepEditd(keys="label", pred="pred", discrepancy="discrepancy"),
diff --git a/tests/test_deepedit_transforms.py b/tests/test_deepedit_transforms.py
index 225b2fc60b1..f608a4342fc 100644
--- a/tests/test_deepedit_transforms.py
+++ b/tests/test_deepedit_transforms.py
@@ -140,7 +140,6 @@
DATA_11 = {"image": IMAGE, "label": LABEL, "label_names": LABEL_NAMES, "pred": PRED}
-
ADD_GUIDANCE_FROM_POINTS_TEST_CASE = [
{"ref_image": "image", "guidance": "guidance", "label_names": LABEL_NAMES}, # arguments
DATA_4, # input_data
@@ -165,7 +164,7 @@
ADD_RANDOM_GUIDANCE_TEST_CASE = [
{"keys": "NA", "guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"}, # arguments
DATA_2, # input_data
- {"spleen": [[3, 5, 4, 6], [-1, -1, -1, -1]], "background": [[-1, -1, -1, -1], [-1, -1, -1, -1]]}, # expected_result
+ 0, # expected_result
]
DISCARD_ADD_GUIDANCE_TEST_CASE = [
@@ -238,7 +237,8 @@ class TestAddRandomGuidanceCustomd(unittest.TestCase):
def test_correct_results(self, arguments, input_data, expected_result):
add_fn = AddRandomGuidanceDeepEditd(**arguments)
result = add_fn(input_data)
- self.assertEqual(result[arguments["guidance"]], expected_result)
+ label_key = list(result[arguments["guidance"]].keys())[0]
+ self.assertGreaterEqual(len(result[arguments["guidance"]][label_key]), expected_result)
class TestDiscardAddGuidanced(unittest.TestCase):
diff --git a/tests/test_deepgrow_dataset.py b/tests/test_deepgrow_dataset.py
index ff8de87b81c..5b3e40b1ee9 100644
--- a/tests/test_deepgrow_dataset.py
+++ b/tests/test_deepgrow_dataset.py
@@ -29,20 +29,20 @@
TEST_CASE_4 = [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1}, 1, 1]
-TEST_CASE_5 = [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1, "image_channel": 4}, 1, 1]
+TEST_CASE_5 = [{"dimension": 3, "pixdim": (1, 1, 1)}, {"length": 1, "image_channel": 1}, 1, 1]
-TEST_CASE_6 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1, "image_channel": 4}, 3, 1]
+TEST_CASE_6 = [{"dimension": 2, "pixdim": (1, 1)}, {"length": 1, "image_channel": 1}, 3, 1]
TEST_CASE_7 = [
{"dimension": 2, "pixdim": (1, 1), "label_key": None},
- {"length": 1, "image_channel": 4, "with_label": False},
+ {"length": 1, "image_channel": 1, "with_label": False},
40,
None,
]
TEST_CASE_8 = [
{"dimension": 3, "pixdim": (1, 1, 1), "label_key": None},
- {"length": 1, "image_channel": 4, "with_label": False},
+ {"length": 1, "image_channel": 1, "with_label": False},
1,
None,
]
diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py
index 436bef0c5ba..bd20b45b6d3 100644
--- a/tests/test_deepgrow_transforms.py
+++ b/tests/test_deepgrow_transforms.py
@@ -226,42 +226,42 @@
]
ADD_GUIDANCE_FROM_POINTS_TEST_CASE_1 = [
- {"ref_image": "image", "dimensions": 3, "guidance": "guidance", "depth_first": True},
+ {"ref_image": "image", "spatial_dims": 3, "guidance": "guidance", "depth_first": True},
DATA_5,
[[0, 2, 2]],
[],
]
ADD_GUIDANCE_FROM_POINTS_TEST_CASE_2 = [
- {"ref_image": "image", "dimensions": 3, "guidance": "guidance", "depth_first": True},
+ {"ref_image": "image", "spatial_dims": 3, "guidance": "guidance", "depth_first": True},
DATA_6,
[[0, 2, 2]],
[[0, 1, 0]],
]
ADD_GUIDANCE_FROM_POINTS_TEST_CASE_3 = [
- {"ref_image": "image", "dimensions": 3, "guidance": "guidance", "depth_first": True},
+ {"ref_image": "image", "spatial_dims": 3, "guidance": "guidance", "depth_first": True},
DATA_7,
[[3, 5, 7], [4, 5, 7]],
[[4, 5, 8]],
]
ADD_GUIDANCE_FROM_POINTS_TEST_CASE_4 = [
- {"ref_image": "image", "dimensions": 2, "guidance": "guidance", "depth_first": True},
+ {"ref_image": "image", "spatial_dims": 2, "guidance": "guidance", "depth_first": True},
DATA_6,
[[2, 2]],
[[1, 0]],
]
ADD_GUIDANCE_FROM_POINTS_TEST_CASE_5 = [
- {"ref_image": "image", "dimensions": 2, "guidance": "guidance", "depth_first": True, "slice_key": "slice"},
+ {"ref_image": "image", "spatial_dims": 2, "guidance": "guidance", "depth_first": True, "slice_key": "slice"},
DATA_7,
[[5, 7]],
[],
]
ADD_GUIDANCE_FROM_POINTS_TEST_CASE_6 = [
- {"ref_image": "image", "dimensions": 2, "guidance": "guidance", "depth_first": True},
+ {"ref_image": "image", "spatial_dims": 2, "guidance": "guidance", "depth_first": True},
DATA_5,
[[2, 2]],
[],
diff --git a/tests/test_denseblock.py b/tests/test_denseblock.py
new file mode 100644
index 00000000000..dd0a1030228
--- /dev/null
+++ b/tests/test_denseblock.py
@@ -0,0 +1,103 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch.nn as nn
+
+from monai.networks.blocks import ConvDenseBlock, DenseBlock
+from tests.utils import TorchImageTestCase2D, TorchImageTestCase3D
+
+
+class TestDenseBlock2D(TorchImageTestCase2D):
+ def test_block_empty(self):
+ block = DenseBlock([])
+ out = block(self.imt)
+ expected_shape = self.imt.shape
+ self.assertEqual(out.shape, expected_shape)
+
+ def test_block_conv(self):
+ conv1 = nn.Conv2d(self.input_channels, self.output_channels, 3, padding=1)
+ conv2 = nn.Conv2d(self.input_channels + self.output_channels, self.input_channels, 3, padding=1)
+ block = DenseBlock([conv1, conv2])
+ out = block(self.imt)
+ expected_shape = (1, self.output_channels + self.input_channels * 2, self.im_shape[0], self.im_shape[1])
+ self.assertEqual(out.shape, expected_shape)
+
+
+class TestDenseBlock3D(TorchImageTestCase3D):
+ def test_block_conv(self):
+ conv1 = nn.Conv3d(self.input_channels, self.output_channels, 3, padding=1)
+ conv2 = nn.Conv3d(self.input_channels + self.output_channels, self.input_channels, 3, padding=1)
+ block = DenseBlock([conv1, conv2])
+ out = block(self.imt)
+ expected_shape = (
+ 1,
+ self.output_channels + self.input_channels * 2,
+ self.im_shape[1],
+ self.im_shape[0],
+ self.im_shape[2],
+ )
+ self.assertEqual(out.shape, expected_shape)
+
+
+class TestConvDenseBlock2D(TorchImageTestCase2D):
+ def test_block_empty(self):
+ conv = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=[])
+ out = conv(self.imt)
+ expected_shape = self.imt.shape
+ self.assertEqual(out.shape, expected_shape)
+
+ def test_except(self):
+ with self.assertRaises(ValueError):
+ _ = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=[1, 2], dilations=[1, 2, 3])
+
+ def test_block1(self):
+ channels = [2, 4]
+ conv = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=channels)
+ out = conv(self.imt)
+ expected_shape = (1, self.input_channels + sum(channels), self.im_shape[0], self.im_shape[1])
+ self.assertEqual(out.shape, expected_shape)
+
+ def test_block2(self):
+ channels = [2, 4]
+ dilations = [1, 2]
+ conv = ConvDenseBlock(spatial_dims=2, in_channels=self.input_channels, channels=channels, dilations=dilations)
+ out = conv(self.imt)
+ expected_shape = (1, self.input_channels + sum(channels), self.im_shape[0], self.im_shape[1])
+ self.assertEqual(out.shape, expected_shape)
+
+
+class TestConvDenseBlock3D(TorchImageTestCase3D):
+ def test_block_empty(self):
+ conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=[])
+ out = conv(self.imt)
+ expected_shape = self.imt.shape
+ self.assertEqual(out.shape, expected_shape)
+
+ def test_block1(self):
+ channels = [2, 4]
+ conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=channels)
+ out = conv(self.imt)
+ expected_shape = (1, self.input_channels + sum(channels), self.im_shape[1], self.im_shape[0], self.im_shape[2])
+ self.assertEqual(out.shape, expected_shape)
+
+ def test_block2(self):
+ channels = [2, 4]
+ dilations = [1, 2]
+ conv = ConvDenseBlock(spatial_dims=3, in_channels=self.input_channels, channels=channels, dilations=dilations)
+ out = conv(self.imt)
+ expected_shape = (1, self.input_channels + sum(channels), self.im_shape[1], self.im_shape[0], self.im_shape[2])
+ self.assertEqual(out.shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_densenet.py b/tests/test_densenet.py
index 47f584297ec..66f27cba51a 100644
--- a/tests/test_densenet.py
+++ b/tests/test_densenet.py
@@ -28,7 +28,6 @@
else:
torchvision, has_torchvision = optional_import("torchvision")
-
device = "cuda" if torch.cuda.is_available() else "cpu"
TEST_CASE_1 = [ # 4-channel 3D, batch 2
@@ -54,10 +53,8 @@
for model in [DenseNet121, Densenet169, densenet201, DenseNet264]:
TEST_CASES.append([model, *case])
-
TEST_SCRIPT_CASES = [[model, *TEST_CASE_1] for model in [DenseNet121, Densenet169, densenet201, DenseNet264]]
-
TEST_PRETRAINED_2D_CASE_1 = [ # 4-channel 2D, batch 2
DenseNet121,
{"pretrained": True, "progress": True, "spatial_dims": 2, "in_channels": 2, "out_channels": 3},
diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
index 3d27994404c..c94c3001756 100644
--- a/tests/test_deprecated.py
+++ b/tests/test_deprecated.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import warnings
diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py
index 83ad5b8d9af..1f43dd8c9aa 100644
--- a/tests/test_dice_ce_loss.py
+++ b/tests/test_dice_ce_loss.py
@@ -35,6 +35,14 @@
},
0.3133,
],
+ [ # shape: (2, 2, 3), (2, 2, 3), one-hot target
+ {"to_onehot_y": False},
+ {
+ "input": torch.tensor([[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]),
+ "target": torch.tensor([[[1, 1, 0], [0, 0, 1]], [[1, 0, 1], [0, 1, 0]]], dtype=torch.uint8),
+ },
+ 0.3133,
+ ],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": False, "to_onehot_y": True, "ce_weight": torch.tensor([1.0, 1.0])},
{
diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py
index b77a36e7209..af3e868654c 100644
--- a/tests/test_dice_focal_loss.py
+++ b/tests/test_dice_focal_loss.py
@@ -13,6 +13,7 @@
import numpy as np
import torch
+from parameterized import parameterized
from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss
from tests.utils import test_script_save
@@ -36,17 +37,24 @@ def test_result_onehot_target_include_bg(self):
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
np.testing.assert_allclose(result, expected_val)
- def test_result_no_onehot_no_bg(self):
- size = [3, 3, 5, 5]
- label = torch.randint(low=0, high=2, size=size)
- label = torch.argmax(label, dim=1, keepdim=True)
+ @parameterized.expand([[[3, 3, 5, 5], True], [[3, 2, 5, 5], False]])
+ def test_result_no_onehot_no_bg(self, size, onehot):
+ label = torch.randint(low=0, high=size[1] - 1, size=size)
+ if onehot:
+ label = torch.argmax(label, dim=1, keepdim=True)
pred = torch.randn(size)
for reduction in ["sum", "mean", "none"]:
- common_params = {"include_background": False, "to_onehot_y": True, "reduction": reduction}
- for focal_weight in [2.0, torch.tensor([1.0, 2.0]), (2.0, 1)]:
+ for focal_weight in [2.0] + [] if size[1] != 3 else [torch.tensor([1.0, 2.0]), (2.0, 1)]:
for lambda_focal in [0.5, 1.0, 1.5]:
+ common_params = {
+ "include_background": False,
+ "softmax": True,
+ "to_onehot_y": onehot,
+ "reduction": reduction,
+ }
dice_focal = DiceFocalLoss(focal_weight=focal_weight, lambda_focal=lambda_focal, **common_params)
dice = DiceLoss(**common_params)
+ common_params.pop("softmax", None)
focal = FocalLoss(weight=focal_weight, **common_params)
result = dice_focal(pred, label)
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
diff --git a/tests/test_ds_loss.py b/tests/test_ds_loss.py
new file mode 100644
index 00000000000..dc67b651a35
--- /dev/null
+++ b/tests/test_ds_loss.py
@@ -0,0 +1,187 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.losses import DeepSupervisionLoss, DiceCELoss, DiceFocalLoss, DiceLoss
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save
+
+TEST_CASES_DICECE = [
+ [
+ {"to_onehot_y": True},
+ {},
+ {
+ "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ "target": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ },
+ 0.606557,
+ ]
+]
+
+TEST_CASES_DICECE2 = [
+ [
+ {"to_onehot_y": True},
+ {},
+ {
+ "input": [
+ torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),
+ torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),
+ ],
+ "target": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ },
+ 1.78144,
+ ],
+ [
+ {"to_onehot_y": True},
+ {"weight_mode": "same"},
+ {
+ "input": [
+ torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),
+ torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),
+ ],
+ "target": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ },
+ 3.5529,
+ ],
+ [
+ {"to_onehot_y": True},
+ {"weight_mode": "two"},
+ {
+ "input": [
+ torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),
+ torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),
+ ],
+ "target": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ },
+ 2.07973,
+ ],
+ [
+ {"to_onehot_y": True},
+ {"weights": [0.1, 0.2, 0.3]},
+ {
+ "input": [
+ torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),
+ torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),
+ ],
+ "target": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ },
+ 0.76924,
+ ],
+]
+
+TEST_CASES_DICE = [
+ [
+ {"to_onehot_y": True},
+ {
+ "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ "target": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ },
+ 0.166666, # the result equals to -1 + np.log(1 + np.exp(1))
+ ],
+ [
+ {"to_onehot_y": True},
+ {
+ "input": [
+ torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),
+ torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),
+ ],
+ "target": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ },
+ 0.666665,
+ ],
+]
+
+TEST_CASES_DICEFOCAL = [
+ [
+ {"to_onehot_y": True},
+ {
+ "input": torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ "target": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ },
+ 0.32124, # the result equals to -1 + np.log(1 + np.exp(1))
+ ],
+ [
+ {"to_onehot_y": True},
+ {
+ "input": [
+ torch.tensor([[[[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ torch.tensor([[[[1.0, 1.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]]),
+ torch.tensor([[[[1.0], [0.0]], [[1.0], [0.0]]]]),
+ ],
+ "target": torch.tensor([[[[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]]]),
+ },
+ 1.06452,
+ ],
+]
+
+
+class TestDSLossDiceCE(unittest.TestCase):
+ @parameterized.expand(TEST_CASES_DICECE)
+ def test_result(self, input_param, input_param2, input_data, expected_val):
+ diceceloss = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2)
+ result = diceceloss(**input_data)
+ np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
+ def test_ill_shape(self):
+ loss = DeepSupervisionLoss(DiceCELoss())
+ with self.assertRaisesRegex(ValueError, ""):
+ loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+ def test_ill_reduction(self):
+ with self.assertRaisesRegex(ValueError, ""):
+ loss = DeepSupervisionLoss(DiceCELoss(reduction="none"))
+ loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+ @SkipIfBeforePyTorchVersion((1, 10))
+ def test_script(self):
+ loss = DeepSupervisionLoss(DiceCELoss())
+ test_input = torch.ones(2, 1, 8, 8)
+ test_script_save(loss, test_input, test_input)
+
+
+@SkipIfBeforePyTorchVersion((1, 11))
+class TestDSLossDiceCE2(unittest.TestCase):
+ @parameterized.expand(TEST_CASES_DICECE2)
+ def test_result(self, input_param, input_param2, input_data, expected_val):
+ diceceloss = DeepSupervisionLoss(DiceCELoss(**input_param), **input_param2)
+ result = diceceloss(**input_data)
+ np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
+
+@SkipIfBeforePyTorchVersion((1, 11))
+class TestDSLossDice(unittest.TestCase):
+ @parameterized.expand(TEST_CASES_DICE)
+ def test_result(self, input_param, input_data, expected_val):
+ loss = DeepSupervisionLoss(DiceLoss(**input_param))
+ result = loss(**input_data)
+ np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
+
+@SkipIfBeforePyTorchVersion((1, 11))
+class TestDSLossDiceFocal(unittest.TestCase):
+ @parameterized.expand(TEST_CASES_DICEFOCAL)
+ def test_result(self, input_param, input_data, expected_val):
+ loss = DeepSupervisionLoss(DiceFocalLoss(**input_param))
+ result = loss(**input_data)
+ np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py
index dab46f366f6..15ea6a09528 100644
--- a/tests/test_ensemble_evaluator.py
+++ b/tests/test_ensemble_evaluator.py
@@ -16,6 +16,7 @@
from parameterized import parameterized
from monai.engines import EnsembleEvaluator
+from tests.utils import assert_allclose
TEST_CASE_1 = [["pred_0", "pred_1", "pred_2", "pred_3", "pred_4"]]
@@ -67,7 +68,7 @@ class CustomEvents(EventEnum):
def run_transform(engine):
for i in range(5):
expected_value = engine.state.iteration + i
- torch.testing.assert_allclose(engine.state.output[0][f"pred_{i}"].item(), expected_value)
+ assert_allclose(engine.state.output[0][f"pred_{i}"].item(), expected_value)
@val_engine.on(Events.EPOCH_COMPLETED)
def trigger_custom_event():
diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py
index 1cb5ac6dec3..d8dba562bb4 100644
--- a/tests/test_ensure_channel_first.py
+++ b/tests/test_ensure_channel_first.py
@@ -13,16 +13,18 @@
import tempfile
import unittest
-import itk
import nibabel as nib
import numpy as np
import torch
from parameterized import parameterized
from PIL import Image
-from monai.data import ITKReader
from monai.data.meta_tensor import MetaTensor
from monai.transforms import EnsureChannelFirst, LoadImage
+from monai.utils import optional_import
+
+itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
+ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator")
TEST_CASE_1 = [{}, ["test_image.nii.gz"], None]
@@ -30,22 +32,25 @@
TEST_CASE_3 = [{}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None]
-TEST_CASE_4 = [{"reader": ITKReader()}, ["test_image.nii.gz"], None]
-
-TEST_CASE_5 = [{"reader": ITKReader()}, ["test_image.nii.gz"], -1]
+TEST_CASE_4 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], None]
-TEST_CASE_6 = [{"reader": ITKReader()}, ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"], None]
+TEST_CASE_5 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], -1]
-TEST_CASE_7 = [{"reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", None]
+TEST_CASE_6 = [
+ {"reader": ITKReader() if has_itk else "itkreader"},
+ ["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
+ None,
+]
class TestEnsureChannelFirst(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
+ @unittest.skipUnless(has_itk, "itk not installed")
def test_load_nifti(self, input_param, filenames, original_channel_dim):
if original_channel_dim is None:
- test_image = np.random.rand(128, 128, 128)
+ test_image = np.random.rand(8, 8, 8)
elif original_channel_dim == -1:
- test_image = np.random.rand(128, 128, 128, 1)
+ test_image = np.random.rand(8, 8, 8, 1)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
@@ -56,32 +61,48 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim):
result = EnsureChannelFirst()(result)
self.assertEqual(result.shape[0], len(filenames))
- @parameterized.expand([TEST_CASE_7])
- def test_itk_dicom_series_reader(self, input_param, filenames, _):
- result = LoadImage(image_only=True, **input_param)(filenames)
+ @unittest.skipUnless(has_itk, "itk not installed")
+ def test_itk_dicom_series_reader(self):
+ filenames = "tests/testing_data/CT_DICOM"
+ itk.ProcessObject.SetGlobalWarningDisplay(False)
+ result = LoadImage(image_only=True, reader=ITKReader(pixel_type=itk.UC))(filenames)
result = EnsureChannelFirst()(result)
self.assertEqual(result.shape[0], 1)
def test_load_png(self):
- spatial_size = (256, 256, 3)
- test_image = np.random.randint(0, 256, size=spatial_size)
+ spatial_size = (6, 6, 3)
+ test_image = np.random.randint(0, 6, size=spatial_size)
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_image.png")
Image.fromarray(test_image.astype("uint8")).save(filename)
result = LoadImage(image_only=True)(filename)
result = EnsureChannelFirst()(result)
self.assertEqual(result.shape[0], 3)
+ result = EnsureChannelFirst(channel_dim=-1)(result)
+ self.assertEqual(result.shape, (6, 3, 6))
def test_check(self):
im = torch.zeros(1, 2, 3)
+ im_nodim = MetaTensor(im, meta={"original_channel_dim": None})
+
with self.assertRaises(ValueError): # not MetaTensor
- EnsureChannelFirst()(im)
+ EnsureChannelFirst(channel_dim=None)(im)
with self.assertRaises(ValueError): # no meta
- EnsureChannelFirst()(MetaTensor(im))
+ EnsureChannelFirst(channel_dim=None)(MetaTensor(im))
with self.assertRaises(ValueError): # no meta channel
- EnsureChannelFirst()(MetaTensor(im, meta={"original_channel_dim": None}))
- EnsureChannelFirst(strict_check=False)(im)
- EnsureChannelFirst(strict_check=False)(MetaTensor(im, meta={"original_channel_dim": None}))
+ EnsureChannelFirst()(im_nodim)
+
+ with self.assertWarns(Warning):
+ EnsureChannelFirst(strict_check=False, channel_dim=None)(im)
+
+ with self.assertWarns(Warning):
+ EnsureChannelFirst(strict_check=False, channel_dim=None)(im_nodim)
+
+ def test_default_channel_first(self):
+ im = torch.rand(4, 4)
+ result = EnsureChannelFirst(channel_dim="no_channel")(im)
+
+ self.assertEqual(result.shape, (1, 4, 4))
if __name__ == "__main__":
diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py
index 8525939f596..44bb7e40f49 100644
--- a/tests/test_ensure_channel_firstd.py
+++ b/tests/test_ensure_channel_firstd.py
@@ -33,9 +33,9 @@ class TestEnsureChannelFirstd(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_load_nifti(self, input_param, filenames, original_channel_dim):
if original_channel_dim is None:
- test_image = np.random.rand(128, 128, 128)
+ test_image = np.random.rand(8, 8, 8)
elif original_channel_dim == -1:
- test_image = np.random.rand(128, 128, 128, 1)
+ test_image = np.random.rand(8, 8, 8, 1)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
@@ -46,7 +46,7 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim):
self.assertEqual(result["img"].shape[0], len(filenames))
def test_load_png(self):
- spatial_size = (256, 256, 3)
+ spatial_size = (6, 6, 3)
test_image = np.random.randint(0, 256, size=spatial_size)
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_image.png")
@@ -57,12 +57,24 @@ def test_load_png(self):
def test_exceptions(self):
im = torch.zeros((1, 2, 3))
+ im_nodim = MetaTensor(im, meta={"original_channel_dim": None})
+
with self.assertRaises(ValueError): # no meta
- EnsureChannelFirstd("img")({"img": im})
+ EnsureChannelFirstd("img", channel_dim=None)({"img": im})
with self.assertRaises(ValueError): # no meta channel
- EnsureChannelFirstd("img")({"img": MetaTensor(im, meta={"original_channel_dim": None})})
- EnsureChannelFirstd("img", strict_check=False)({"img": im})
- EnsureChannelFirstd("img", strict_check=False)({"img": MetaTensor(im, meta={"original_channel_dim": None})})
+ EnsureChannelFirstd("img", channel_dim=None)({"img": im_nodim})
+
+ with self.assertWarns(Warning):
+ EnsureChannelFirstd("img", strict_check=False, channel_dim=None)({"img": im})
+
+ with self.assertWarns(Warning):
+ EnsureChannelFirstd("img", strict_check=False, channel_dim=None)({"img": im_nodim})
+
+ def test_default_channel_first(self):
+ im = torch.rand(4, 4)
+ result = EnsureChannelFirstd("img", channel_dim="no_channel")({"img": im})
+
+ self.assertEqual(result["img"].shape, (1, 4, 4))
if __name__ == "__main__":
diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py
index 55423838b80..9325e0b601f 100644
--- a/tests/test_ensure_type.py
+++ b/tests/test_ensure_type.py
@@ -63,12 +63,12 @@ def test_list_tuple(self):
result = EnsureType(data_type=dtype, wrap_sequence=False, track_meta=True)([[1, 2], [3, 4]])
self.assertTrue(isinstance(result, list))
self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray))
- torch.testing.assert_allclose(result[1][0], torch.as_tensor(3))
+ assert_allclose(result[1][0], torch.as_tensor(3), type_test=False)
# tuple of numpy arrays
result = EnsureType(data_type=dtype, wrap_sequence=False)((np.array([1, 2]), np.array([3, 4])))
self.assertTrue(isinstance(result, tuple))
self.assertTrue(isinstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray))
- torch.testing.assert_allclose(result[1], torch.as_tensor([3, 4]))
+ assert_allclose(result[1], torch.as_tensor([3, 4]), type_test=False)
def test_dict(self):
# simulate complicated input data
@@ -81,9 +81,9 @@ def test_dict(self):
result = EnsureType(data_type=dtype, track_meta=False)(test_data)
self.assertTrue(isinstance(result, dict))
self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
- torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]))
+ assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False)
self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray))
- torch.testing.assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]))
+ assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False)
self.assertEqual(result["meta"]["path"], "temp/test")
self.assertEqual(result["extra"], None)
diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py
index d57170e2a6b..789afd1a465 100644
--- a/tests/test_ensure_typed.py
+++ b/tests/test_ensure_typed.py
@@ -67,14 +67,14 @@ def test_list_tuple(self):
)["data"]
self.assertTrue(isinstance(result, list))
self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray))
- torch.testing.assert_allclose(result[1][0], torch.as_tensor(3))
+ assert_allclose(result[1][0], torch.as_tensor(3), type_test=False)
# tuple of numpy arrays
result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)(
{"data": (np.array([1, 2]), np.array([3, 4]))}
)["data"]
self.assertTrue(isinstance(result, tuple))
self.assertTrue(isinstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray))
- torch.testing.assert_allclose(result[1], torch.as_tensor([3, 4]))
+ assert_allclose(result[1], torch.as_tensor([3, 4]), type_test=False)
def test_dict(self):
# simulate complicated input data
@@ -87,9 +87,9 @@ def test_dict(self):
result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"]
self.assertTrue(isinstance(result, dict))
self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
- torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]))
+ assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False)
self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray))
- torch.testing.assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]))
+ assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False)
self.assertEqual(result["meta"]["path"], "temp/test")
self.assertEqual(result["extra"], None)
diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py
index 1bb3d887a04..5e4b0b3b5d0 100644
--- a/tests/test_evenly_divisible_all_gather_dist.py
+++ b/tests/test_evenly_divisible_all_gather_dist.py
@@ -15,7 +15,7 @@
import torch.distributed as dist
from monai.utils import evenly_divisible_all_gather
-from tests.utils import DistCall, DistTestCase
+from tests.utils import DistCall, DistTestCase, assert_allclose
class DistributedEvenlyDivisibleAllGather(DistTestCase):
@@ -35,10 +35,10 @@ def _run(self):
data3 = torch.tensor(8)
result1 = evenly_divisible_all_gather(data=data1, concat=True)
- torch.testing.assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]]))
+ assert_allclose(result1, torch.tensor([[1, 2], [3, 4], [5, 6]]))
result2 = evenly_divisible_all_gather(data=data2, concat=False)
for r, e in zip(result2, [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0], [5.0, 6.0]])]):
- torch.testing.assert_allclose(r, e)
+ assert_allclose(r, e)
result3 = evenly_divisible_all_gather(data=data3, concat=False)
for r in result3:
self.assertEqual(r.ndimension(), 0)
diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py
index 4292ff3a226..688c65005e6 100644
--- a/tests/test_fill_holes.py
+++ b/tests/test_fill_holes.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import torch
diff --git a/tests/test_fill_holesd.py b/tests/test_fill_holesd.py
index fce90fd86ac..7711df36b38 100644
--- a/tests/test_fill_holesd.py
+++ b/tests/test_fill_holesd.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import torch
diff --git a/tests/test_fl_exchange_object.py b/tests/test_fl_exchange_object.py
new file mode 100644
index 00000000000..bb2d0372db1
--- /dev/null
+++ b/tests/test_fl_exchange_object.py
@@ -0,0 +1,61 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.fl.utils.constants import WeightType
+from monai.fl.utils.exchange_object import ExchangeObject
+from monai.utils.module import optional_import
+from tests.utils import SkipIfNoModule
+
+models, has_torchvision = optional_import("torchvision.models")
+
+TEST_INIT_1 = [{"weights": None, "optim": None, "metrics": None, "weight_type": None, "statistics": None}, "{}"]
+TEST_INIT_2: list = []
+if has_torchvision:
+ network = models.resnet18()
+ TEST_INIT_2.append(
+ {
+ "weights": network.state_dict(),
+ "optim": torch.optim.Adam(lr=1, params=network.parameters()).state_dict(),
+ "metrics": {"accuracy": 1},
+ "weight_type": WeightType.WEIGHT_DIFF,
+ "statistics": {"some_stat": 1},
+ }
+ )
+ TEST_INIT_2.append("{'weights': 122, 'optim': 2, 'metrics': 1, 'weight_type': fl_weight_diff, 'statistics': 1}")
+
+TEST_FAILURE_METRICS = [{"weights": None, "optim": None, "metrics": 1, "weight_type": None, "statistics": None}]
+TEST_FAILURE_STATISTICS = [{"weights": None, "optim": None, "metrics": None, "weight_type": None, "statistics": 1}]
+TEST_FAILURE_WEIGHT_TYPE = [{"weights": None, "optim": None, "metrics": None, "weight_type": 1, "statistics": None}]
+
+
+@SkipIfNoModule("torchvision")
+@SkipIfNoModule("ignite")
+class TestFLExchangeObject(unittest.TestCase):
+ @parameterized.expand([TEST_INIT_1, TEST_INIT_2])
+ def test_init(self, input_params, expected_str):
+ eo = ExchangeObject(**input_params)
+ self.assertIsInstance(eo, ExchangeObject)
+ eo.summary()
+ self.assertEqual(repr(eo), expected_str)
+
+ @parameterized.expand([TEST_FAILURE_METRICS, TEST_FAILURE_STATISTICS, TEST_FAILURE_WEIGHT_TYPE])
+ def test_failures(self, input_params):
+ with self.assertRaises(ValueError):
+ ExchangeObject(**input_params)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py
new file mode 100644
index 00000000000..0627235a184
--- /dev/null
+++ b/tests/test_fl_monai_algo.py
@@ -0,0 +1,202 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+from parameterized import parameterized
+
+from monai.bundle import ConfigParser
+from monai.fl.client.monai_algo import MonaiAlgo
+from monai.fl.utils.constants import ExtraItems
+from monai.fl.utils.exchange_object import ExchangeObject
+from tests.utils import SkipIfNoModule
+
+_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__)))
+_data_dir = os.path.join(_root_dir, "testing_data")
+
+TEST_TRAIN_1 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
+ "config_evaluate_filename": None,
+ "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
+ }
+]
+TEST_TRAIN_2 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
+ "config_evaluate_filename": None,
+ "config_filters_filename": None,
+ }
+]
+TEST_TRAIN_3 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": [
+ os.path.join(_data_dir, "config_fl_train.json"),
+ os.path.join(_data_dir, "config_fl_train.json"),
+ ],
+ "config_evaluate_filename": None,
+ "config_filters_filename": [
+ os.path.join(_data_dir, "config_fl_filters.json"),
+ os.path.join(_data_dir, "config_fl_filters.json"),
+ ],
+ }
+]
+
+TEST_EVALUATE_1 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": None,
+ "config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
+ "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
+ }
+]
+TEST_EVALUATE_2 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": None,
+ "config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
+ "config_filters_filename": None,
+ }
+]
+TEST_EVALUATE_3 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": None,
+ "config_evaluate_filename": [
+ os.path.join(_data_dir, "config_fl_evaluate.json"),
+ os.path.join(_data_dir, "config_fl_evaluate.json"),
+ ],
+ "config_filters_filename": [
+ os.path.join(_data_dir, "config_fl_filters.json"),
+ os.path.join(_data_dir, "config_fl_filters.json"),
+ ],
+ }
+]
+
+TEST_GET_WEIGHTS_1 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
+ "config_evaluate_filename": None,
+ "send_weight_diff": False,
+ "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
+ }
+]
+TEST_GET_WEIGHTS_2 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": None,
+ "config_evaluate_filename": None,
+ "send_weight_diff": False,
+ "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
+ }
+]
+TEST_GET_WEIGHTS_3 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
+ "config_evaluate_filename": None,
+ "send_weight_diff": True,
+ "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
+ }
+]
+TEST_GET_WEIGHTS_4 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": [
+ os.path.join(_data_dir, "config_fl_train.json"),
+ os.path.join(_data_dir, "config_fl_train.json"),
+ ],
+ "config_evaluate_filename": None,
+ "send_weight_diff": True,
+ "config_filters_filename": [
+ os.path.join(_data_dir, "config_fl_filters.json"),
+ os.path.join(_data_dir, "config_fl_filters.json"),
+ ],
+ }
+]
+
+
+@SkipIfNoModule("ignite")
+class TestFLMonaiAlgo(unittest.TestCase):
+ @parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3])
+ def test_train(self, input_params):
+ # get testing data dir and update train config; using the first to define data dir
+ if isinstance(input_params["config_train_filename"], list):
+ config_train_filename = [
+ os.path.join(input_params["bundle_root"], x) for x in input_params["config_train_filename"]
+ ]
+ else:
+ config_train_filename = os.path.join(input_params["bundle_root"], input_params["config_train_filename"])
+
+ # initialize algo
+ algo = MonaiAlgo(**input_params)
+ algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
+ algo.abort()
+
+ # initialize model
+ parser = ConfigParser()
+ parser.read_config(config_train_filename)
+ parser.parse()
+ network = parser.get_parsed_content("network")
+
+ data = ExchangeObject(weights=network.state_dict())
+
+ # test train
+ algo.train(data=data, extra={})
+ algo.finalize()
+
+ @parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3])
+ def test_evaluate(self, input_params):
+ # get testing data dir and update train config; using the first to define data dir
+ if isinstance(input_params["config_evaluate_filename"], list):
+ config_eval_filename = [
+ os.path.join(input_params["bundle_root"], x) for x in input_params["config_evaluate_filename"]
+ ]
+ else:
+ config_eval_filename = os.path.join(input_params["bundle_root"], input_params["config_evaluate_filename"])
+
+ # initialize algo
+ algo = MonaiAlgo(**input_params)
+ algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
+
+ # initialize model
+ parser = ConfigParser()
+ parser.read_config(config_eval_filename)
+ parser.parse()
+ network = parser.get_parsed_content("network")
+
+ data = ExchangeObject(weights=network.state_dict())
+
+ # test evaluate
+ algo.evaluate(data=data, extra={})
+
+ @parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3, TEST_GET_WEIGHTS_4])
+ def test_get_weights(self, input_params):
+ # initialize algo
+ algo = MonaiAlgo(**input_params)
+ algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
+
+ # test train
+ if input_params["send_weight_diff"]: # should not work as test doesn't receive a global model
+ with self.assertRaises(ValueError):
+ weights = algo.get_weights(extra={})
+ else:
+ weights = algo.get_weights(extra={})
+ self.assertIsInstance(weights, ExchangeObject)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_fl_monai_algo_dist.py b/tests/test_fl_monai_algo_dist.py
new file mode 100644
index 00000000000..11f64ea318f
--- /dev/null
+++ b/tests/test_fl_monai_algo_dist.py
@@ -0,0 +1,97 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from os.path import join as pathjoin
+
+import torch.distributed as dist
+from parameterized import parameterized
+
+from monai.bundle import ConfigParser
+from monai.fl.client.monai_algo import MonaiAlgo
+from monai.fl.utils.constants import ExtraItems
+from monai.fl.utils.exchange_object import ExchangeObject
+from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion, SkipIfNoModule, skip_if_no_cuda
+
+_root_dir = os.path.abspath(pathjoin(os.path.dirname(__file__)))
+_data_dir = pathjoin(_root_dir, "testing_data")
+TEST_TRAIN_1 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": [
+ pathjoin(_data_dir, "config_fl_train.json"),
+ pathjoin(_data_dir, "multi_gpu_train.json"),
+ ],
+ "config_evaluate_filename": None,
+ "config_filters_filename": pathjoin(_root_dir, "testing_data", "config_fl_filters.json"),
+ "multi_gpu": True,
+ }
+]
+
+TEST_EVALUATE_1 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": None,
+ "config_evaluate_filename": [
+ pathjoin(_data_dir, "config_fl_evaluate.json"),
+ pathjoin(_data_dir, "multi_gpu_evaluate.json"),
+ ],
+ "config_filters_filename": pathjoin(_data_dir, "config_fl_filters.json"),
+ "multi_gpu": True,
+ }
+]
+
+
+@SkipIfNoModule("ignite")
+@SkipIfBeforePyTorchVersion((1, 11, 1))
+class TestFLMonaiAlgo(DistTestCase):
+ @parameterized.expand([TEST_TRAIN_1])
+ @DistCall(nnodes=1, nproc_per_node=2, init_method="no_init")
+ @skip_if_no_cuda
+ def test_train(self, input_params):
+ # initialize algo
+ algo = MonaiAlgo(**input_params)
+ algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
+ self.assertTrue(dist.get_rank() in (0, 1))
+
+ # initialize model
+ parser = ConfigParser()
+ parser.read_config([pathjoin(input_params["bundle_root"], x) for x in input_params["config_train_filename"]])
+ parser.parse()
+ network = parser.get_parsed_content("network")
+ data = ExchangeObject(weights=network.state_dict())
+ # test train
+ algo.train(data=data, extra={})
+
+ @parameterized.expand([TEST_EVALUATE_1])
+ @DistCall(nnodes=1, nproc_per_node=2, init_method="no_init")
+ @skip_if_no_cuda
+ def test_evaluate(self, input_params):
+ # initialize algo
+ algo = MonaiAlgo(**input_params)
+ algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
+ self.assertTrue(dist.get_rank() in (0, 1))
+
+ # initialize model
+ parser = ConfigParser()
+ parser.read_config(
+ [os.path.join(input_params["bundle_root"], x) for x in input_params["config_evaluate_filename"]]
+ )
+ parser.parse()
+ network = parser.get_parsed_content("network")
+ data = ExchangeObject(weights=network.state_dict())
+ # test evaluate
+ algo.evaluate(data=data, extra={})
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_fl_monai_algo_stats.py b/tests/test_fl_monai_algo_stats.py
new file mode 100644
index 00000000000..fd2b73ea85e
--- /dev/null
+++ b/tests/test_fl_monai_algo_stats.py
@@ -0,0 +1,69 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+from parameterized import parameterized
+
+from monai.fl.client import MonaiAlgoStats
+from monai.fl.utils.constants import ExtraItems, FlStatistics
+from monai.fl.utils.exchange_object import ExchangeObject
+from tests.utils import SkipIfNoModule
+
+_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__)))
+_data_dir = os.path.join(_root_dir, "testing_data")
+
+TEST_GET_DATA_STATS_1 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": os.path.join(_data_dir, "config_fl_stats_1.json"),
+ "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
+ }
+]
+TEST_GET_DATA_STATS_2 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": os.path.join(_data_dir, "config_fl_stats_2.json"),
+ "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
+ }
+]
+TEST_GET_DATA_STATS_3 = [
+ {
+ "bundle_root": _data_dir,
+ "config_train_filename": [
+ os.path.join(_data_dir, "config_fl_stats_1.json"),
+ os.path.join(_data_dir, "config_fl_stats_2.json"),
+ ],
+ "config_filters_filename": [
+ os.path.join(_data_dir, "config_fl_filters.json"),
+ os.path.join(_data_dir, "config_fl_filters.json"),
+ ],
+ }
+]
+
+
+@SkipIfNoModule("ignite")
+class TestFLMonaiAlgo(unittest.TestCase):
+ @parameterized.expand([TEST_GET_DATA_STATS_1, TEST_GET_DATA_STATS_2, TEST_GET_DATA_STATS_3])
+ def test_get_data_stats(self, input_params):
+ # initialize algo
+ algo = MonaiAlgoStats(**input_params)
+ algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl", ExtraItems.APP_ROOT: _data_dir})
+
+ requested_stats = {FlStatistics.HIST_BINS: 100, FlStatistics.HIST_RANGE: [-500, 500]}
+ # test train
+ stats = algo.get_data_stats(extra=requested_stats)
+ self.assertIsInstance(stats, ExchangeObject)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py
new file mode 100644
index 00000000000..b71afc80cff
--- /dev/null
+++ b/tests/test_flexible_unet.py
@@ -0,0 +1,422 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from typing import Union
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.blocks.encoder import BaseEncoder
+from monai.networks.nets import (
+ FLEXUNET_BACKBONE,
+ EfficientNetBNFeatures,
+ FlexibleUNet,
+ FlexUNetEncoderRegister,
+ ResNet,
+ ResNetBlock,
+ ResNetBottleneck,
+)
+from monai.utils import optional_import
+from tests.utils import skip_if_downloading_fails, skip_if_quick
+
+torchvision, has_torchvision = optional_import("torchvision")
+PIL, has_pil = optional_import("PIL")
+
+
+class DummyEncoder(BaseEncoder):
+ @classmethod
+ def get_encoder_parameters(cls):
+ basic_dict = {"spatial_dims": 2, "in_channels": 3, "pretrained": False}
+ param_dict_list = [basic_dict]
+ for key in basic_dict:
+ cur_dict = basic_dict.copy()
+ del cur_dict[key]
+ param_dict_list.append(cur_dict)
+ return param_dict_list
+
+ @classmethod
+ def num_channels_per_output(cls):
+
+ return [(32, 64, 128, 256, 512, 1024), (32, 64, 128, 256), (32, 64, 128, 256), (32, 64, 128, 256)]
+
+ @classmethod
+ def num_outputs(cls):
+
+ return [6, 4, 4, 4]
+
+ @classmethod
+ def get_encoder_names(cls):
+ return ["encoder_wrong_channels", "encoder_no_param1", "encoder_no_param2", "encoder_no_param3"]
+
+
+class ResNetEncoder(ResNet, BaseEncoder):
+ backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]
+ output_feature_channels = [(64, 128, 256, 512)] * 3 + [(256, 512, 1024, 2048)] * 4
+ parameter_layers = [
+ [1, 1, 1, 1],
+ [2, 2, 2, 2],
+ [3, 4, 6, 3],
+ [3, 4, 6, 3],
+ [3, 4, 23, 3],
+ [3, 8, 36, 3],
+ [3, 24, 36, 3],
+ ]
+
+ def __init__(self, in_channels, pretrained, **kargs):
+ super().__init__(**kargs, n_input_channels=in_channels)
+ if pretrained:
+ # Author of paper zipped the state_dict on googledrive,
+ # so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
+ # Would like to load dict from url but need somewhere to save the state dicts.
+ raise NotImplementedError(
+ "Currently not implemented. You need to manually download weights provided by the paper's author"
+ " and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet"
+ )
+
+ @staticmethod
+ def get_inplanes():
+ return [64, 128, 256, 512]
+
+ @classmethod
+ def get_encoder_parameters(cls):
+ """
+ Get parameter list to initialize encoder networks.
+ Each parameter dict must have `spatial_dims`, `in_channels`
+ and `pretrained` parameters.
+ """
+ parameter_list = []
+ res_type: Union[ResNetBlock, ResNetBottleneck]
+ for backbone in range(len(cls.backbone_names)):
+ if backbone < 3:
+ res_type = ResNetBlock
+ else:
+ res_type = ResNetBottleneck
+ parameter_list.append(
+ {
+ "block": res_type,
+ "layers": cls.parameter_layers[backbone],
+ "block_inplanes": ResNetEncoder.get_inplanes(),
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "pretrained": False,
+ }
+ )
+ return parameter_list
+
+ @classmethod
+ def num_channels_per_output(cls):
+ """
+ Get number of output features' channel.
+ """
+ return cls.output_feature_channels
+
+ @classmethod
+ def num_outputs(cls):
+ """
+ Get number of output feature.
+ """
+ return [4] * 7
+
+ @classmethod
+ def get_encoder_names(cls):
+ """
+ Get the name string of backbones which will be used to initialize flexible unet.
+ """
+ return cls.backbone_names
+
+ def forward(self, x: torch.Tensor):
+ feature_list = []
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ if not self.no_max_pool:
+ x = self.maxpool(x)
+ x = self.layer1(x)
+ feature_list.append(x)
+ x = self.layer2(x)
+ feature_list.append(x)
+ x = self.layer3(x)
+ feature_list.append(x)
+ x = self.layer4(x)
+ feature_list.append(x)
+
+ return feature_list
+
+
+FLEXUNET_BACKBONE.regist_class(ResNetEncoder)
+FLEXUNET_BACKBONE.regist_class(DummyEncoder)
+
+
+def get_model_names():
+ return [f"efficientnet-b{d}" for d in range(8)]
+
+
+def get_resnet_names():
+ return ResNetEncoder.get_encoder_names()
+
+
+def make_shape_cases(
+ models,
+ spatial_dims,
+ batches,
+ pretrained,
+ in_channels=3,
+ num_classes=10,
+ input_shape=64,
+ norm=("batch", {"eps": 1e-3, "momentum": 0.01}),
+):
+ ret_tests = []
+ for spatial_dim in spatial_dims: # selected spatial_dims
+ for batch in batches: # check single batch as well as multiple batch input
+ for model in models: # selected models
+ for is_pretrained in pretrained: # pretrained or not pretrained
+ if ("resnet" in model) and is_pretrained:
+ continue
+ kwargs = {
+ "in_channels": in_channels,
+ "out_channels": num_classes,
+ "backbone": model,
+ "pretrained": is_pretrained,
+ "spatial_dims": spatial_dim,
+ "norm": norm,
+ }
+ ret_tests.append(
+ [
+ kwargs,
+ (batch, in_channels) + (input_shape,) * spatial_dim,
+ (batch, num_classes) + (input_shape,) * spatial_dim,
+ ]
+ )
+ return ret_tests
+
+
+def make_error_case():
+ error_dummy_backbones = DummyEncoder.get_encoder_names()
+ error_resnet_backbones = ResNetEncoder.get_encoder_names()
+ error_backbones = error_dummy_backbones + error_resnet_backbones
+ error_param_list = []
+ for backbone in error_backbones:
+ error_param_list.append(
+ [{"in_channels": 3, "out_channels": 2, "backbone": backbone, "pretrained": True, "spatial_dims": 3}]
+ )
+ return error_param_list
+
+
+# create list of selected models to speed up redundant tests
+# only test efficient net B0, B3 and resnet 10, 18, 34
+SEL_MODELS = [get_model_names()[i] for i in [0, 3]]
+SEL_MODELS += [get_resnet_names()[i] for i in [0, 1, 2]]
+
+# pretrained=False cases
+# 2D and 3D models are expensive so use selected models
+CASES_2D = make_shape_cases(
+ models=SEL_MODELS,
+ spatial_dims=[2],
+ batches=[1, 4],
+ pretrained=[False],
+ in_channels=3,
+ num_classes=10,
+ norm="instance",
+)
+CASES_3D = make_shape_cases(
+ models=[SEL_MODELS[0]],
+ spatial_dims=[3],
+ batches=[1],
+ pretrained=[False],
+ in_channels=3,
+ num_classes=10,
+ norm="batch",
+)
+
+# varying num_classes and in_channels
+CASES_VARIATIONS = []
+
+# change num_classes test
+# 20 classes
+# 2D
+CASES_VARIATIONS.extend(
+ make_shape_cases(
+ models=SEL_MODELS, spatial_dims=[2], batches=[1], pretrained=[False, True], in_channels=3, num_classes=20
+ )
+)
+# 3D
+CASES_VARIATIONS.extend(
+ make_shape_cases(
+ models=[SEL_MODELS[0]], spatial_dims=[3], batches=[1], pretrained=[False], in_channels=3, num_classes=20
+ )
+)
+
+# change in_channels test
+# 1 channel
+# 2D
+CASES_VARIATIONS.extend(
+ make_shape_cases(
+ models=SEL_MODELS, spatial_dims=[2], batches=[1], pretrained=[False, True], in_channels=1, num_classes=10
+ )
+)
+# 8 channel
+# 2D
+CASES_VARIATIONS.extend(
+ make_shape_cases(
+ models=SEL_MODELS, spatial_dims=[2], batches=[1], pretrained=[False, True], in_channels=8, num_classes=10
+ )
+)
+# 3D
+CASES_VARIATIONS.extend(
+ make_shape_cases(
+ models=[SEL_MODELS[0]], spatial_dims=[3], batches=[1], pretrained=[False], in_channels=1, num_classes=10
+ )
+)
+
+# change input shape test
+# 96
+# 2D 96x96 input
+CASES_VARIATIONS.extend(
+ make_shape_cases(
+ models=SEL_MODELS,
+ spatial_dims=[2],
+ batches=[1],
+ pretrained=[False, True],
+ in_channels=3,
+ num_classes=10,
+ input_shape=96,
+ )
+)
+# 2D 64x64 input
+CASES_VARIATIONS.extend(
+ make_shape_cases(
+ models=SEL_MODELS,
+ spatial_dims=[2],
+ batches=[1],
+ pretrained=[False, True],
+ in_channels=3,
+ num_classes=10,
+ input_shape=64,
+ )
+)
+
+# 3D 32x32x32 input
+CASES_VARIATIONS.extend(
+ make_shape_cases(
+ models=SEL_MODELS,
+ spatial_dims=[2],
+ batches=[1],
+ pretrained=[False],
+ in_channels=3,
+ num_classes=10,
+ input_shape=32,
+ )
+)
+
+# 3D 64x64x64 input
+CASES_VARIATIONS.extend(
+ make_shape_cases(
+ models=SEL_MODELS,
+ spatial_dims=[2],
+ batches=[1],
+ pretrained=[False],
+ in_channels=3,
+ num_classes=10,
+ input_shape=64,
+ )
+)
+
+# pretrain weight verified
+CASES_PRETRAIN = [
+ (
+ {
+ "in_channels": 3,
+ "out_channels": 10,
+ "backbone": SEL_MODELS[0],
+ "pretrained": True,
+ "spatial_dims": 2,
+ "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
+ },
+ {
+ "in_channels": 3,
+ "num_classes": 10,
+ "model_name": SEL_MODELS[0],
+ "pretrained": True,
+ "spatial_dims": 2,
+ "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
+ },
+ ["_conv_stem.weight"],
+ )
+]
+
+CASE_ERRORS = make_error_case()
+
+# Verify Register class with string type
+CASE_REGISTER_ENCODER = ["EfficientNetEncoder", "monai.networks.nets.EfficientNetEncoder"]
+
+
+@skip_if_quick
+class TestFLEXIBLEUNET(unittest.TestCase):
+ @parameterized.expand(CASES_2D + CASES_3D + CASES_VARIATIONS)
+ def test_shape(self, input_param, input_shape, expected_shape):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ with skip_if_downloading_fails():
+ net = FlexibleUNet(**input_param).to(device)
+
+ # run inference with random tensor
+ with eval_mode(net):
+ result = net(torch.randn(input_shape).to(device))
+
+ # check output shape
+ self.assertEqual(result.shape, expected_shape)
+
+ @parameterized.expand(CASES_PRETRAIN)
+ def test_pretrain(self, input_param, efficient_input_param, weight_list):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ with skip_if_downloading_fails():
+ net = FlexibleUNet(**input_param).to(device)
+
+ with skip_if_downloading_fails():
+ eff_net = EfficientNetBNFeatures(**efficient_input_param).to(device)
+
+ for weight_name in weight_list:
+ if weight_name in net.encoder.state_dict() and weight_name in eff_net.state_dict():
+ net_weight = net.encoder.state_dict()[weight_name]
+ download_weight = eff_net.state_dict()[weight_name]
+ weight_diff = torch.abs(net_weight - download_weight)
+ diff_sum = torch.sum(weight_diff)
+ # check if a weight in weight_list equals to the downloaded weight.
+ self.assertLess(abs(diff_sum.item() - 0), 1e-8)
+
+ @parameterized.expand(CASE_ERRORS)
+ def test_error_raise(self, input_param):
+ with self.assertRaises((ValueError, NotImplementedError)):
+ FlexibleUNet(**input_param)
+
+
+class TestFlexUNetEncoderRegister(unittest.TestCase):
+ @parameterized.expand(CASE_REGISTER_ENCODER)
+ def test_regist(self, encoder):
+ tmp_backbone = FlexUNetEncoderRegister()
+ tmp_backbone.regist_class(encoder)
+ for backbone in tmp_backbone.register_dict:
+ backbone_type = tmp_backbone.register_dict[backbone]["type"]
+ feature_number = backbone_type.num_outputs()
+ feature_channel = backbone_type.num_channels_per_output()
+ param_dict_list = backbone_type.get_encoder_parameters()
+ encoder_name_list = backbone_type.get_encoder_names()
+ encoder_cnt = encoder_name_list.index(backbone)
+ self.assertEqual(feature_number[encoder_cnt], tmp_backbone.register_dict[backbone]["feature_number"])
+ self.assertEqual(feature_channel[encoder_cnt], tmp_backbone.register_dict[backbone]["feature_channel"])
+ self.assertEqual(param_dict_list[encoder_cnt], tmp_backbone.register_dict[backbone]["parameter"])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_fpn_block.py b/tests/test_fpn_block.py
index 420fd04367c..a86cd22a19d 100644
--- a/tests/test_fpn_block.py
+++ b/tests/test_fpn_block.py
@@ -19,7 +19,7 @@
from monai.networks.blocks.feature_pyramid_network import FeaturePyramidNetwork
from monai.networks.nets.resnet import resnet50
from monai.utils import optional_import
-from tests.utils import test_script_save
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save
_, has_torchvision = optional_import("torchvision")
@@ -53,6 +53,7 @@ def test_fpn_block(self, input_param, input_shape, expected_shape):
self.assertEqual(result["feat1"].shape, expected_shape[1])
@parameterized.expand(TEST_CASES)
+ @SkipIfBeforePyTorchVersion((1, 9, 1))
def test_script(self, input_param, input_shape, expected_shape):
# test whether support torchscript
net = FeaturePyramidNetwork(**input_param)
@@ -73,6 +74,7 @@ def test_fpn(self, input_param, input_shape, expected_shape):
self.assertEqual(result["pool"].shape, expected_shape[1])
@parameterized.expand(TEST_CASES2)
+ @SkipIfBeforePyTorchVersion((1, 9, 1))
def test_script(self, input_param, input_shape, expected_shape):
# test whether support torchscript
net = _resnet_fpn_extractor(backbone=resnet50(), spatial_dims=input_param["spatial_dims"], returned_layers=[1])
diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py
index 81f8f4c0b0a..619814037bc 100644
--- a/tests/test_generalized_dice_loss.py
+++ b/tests/test_generalized_dice_loss.py
@@ -46,7 +46,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- 0.469964,
+ 0.435035,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
@@ -54,7 +54,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- 0.414507,
+ 0.3837,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{
@@ -69,7 +69,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- 0.829015,
+ 1.5348,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{
@@ -84,7 +84,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- [[[0.273476]], [[0.555539]]],
+ [[[0.210949], [0.295351]], [[0.599976], [0.428522]]],
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8},
@@ -112,7 +112,7 @@
"input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1, 1, 0, 0]]]),
},
- 0.250023,
+ 0.26669,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
@@ -134,7 +134,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- -0.097833,
+ -8.55485,
],
]
diff --git a/tests/test_generate_distance_map.py b/tests/test_generate_distance_map.py
new file mode 100644
index 00000000000..0be252dbf83
--- /dev/null
+++ b/tests/test_generate_distance_map.py
@@ -0,0 +1,51 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.array import GenerateDistanceMap
+from monai.transforms.intensity.array import GaussianSmooth
+from tests.utils import TEST_NDARRAYS
+
+EXCEPTION_TESTS = []
+TESTS = []
+
+np.random.RandomState(123)
+
+for p in TEST_NDARRAYS:
+ EXCEPTION_TESTS.append([{}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError])
+
+ EXCEPTION_TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError])
+
+for p in TEST_NDARRAYS:
+ TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)])
+ TESTS.append(
+ [{"smooth_fn": GaussianSmooth(sigma=0.4)}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)]
+ )
+
+
+class TestGenerateDistanceMap(unittest.TestCase):
+ @parameterized.expand(EXCEPTION_TESTS)
+ def test_value(self, argments, mask, probmap, exception_type):
+ with self.assertRaises(exception_type):
+ GenerateDistanceMap(**argments)(mask, probmap)
+
+ @parameterized.expand(TESTS)
+ def test_value2(self, argments, mask, probmap, expected_shape):
+ result = GenerateDistanceMap(**argments)(mask, probmap)
+ self.assertEqual(result.shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_distance_mapd.py b/tests/test_generate_distance_mapd.py
new file mode 100644
index 00000000000..fb6e59f36b0
--- /dev/null
+++ b/tests/test_generate_distance_mapd.py
@@ -0,0 +1,62 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.dictionary import GenerateDistanceMapd
+from monai.transforms.intensity.array import GaussianSmooth
+from tests.utils import TEST_NDARRAYS
+
+EXCEPTION_TESTS = []
+TESTS = []
+
+np.random.RandomState(123)
+
+for p in TEST_NDARRAYS:
+ EXCEPTION_TESTS.append(
+ [{"keys": "mask", "border_key": "border"}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError]
+ )
+
+ EXCEPTION_TESTS.append(
+ [{"keys": "mask", "border_key": "border"}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError]
+ )
+
+for p in TEST_NDARRAYS:
+ TESTS.append(
+ [{"keys": "mask", "border_key": "border"}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)]
+ )
+ TESTS.append(
+ [
+ {"keys": "mask", "border_key": "border", "smooth_fn": GaussianSmooth(sigma=0.4)},
+ p(np.random.rand(1, 5, 5)),
+ p(np.random.rand(1, 5, 5)),
+ (1, 5, 5),
+ ]
+ )
+
+
+class TestGenerateDistanceMapd(unittest.TestCase):
+ @parameterized.expand(EXCEPTION_TESTS)
+ def test_value(self, argments, mask, border_map, exception_type):
+ with self.assertRaises(exception_type):
+ GenerateDistanceMapd(**argments)({"mask": mask, "border": border_map})
+
+ @parameterized.expand(TESTS)
+ def test_value2(self, argments, mask, border_map, expected_shape):
+ result = GenerateDistanceMapd(**argments)({"mask": mask, "border": border_map})
+ self.assertEqual(result["dist"].shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_instance_border.py b/tests/test_generate_instance_border.py
new file mode 100644
index 00000000000..1cb7e39c31b
--- /dev/null
+++ b/tests/test_generate_instance_border.py
@@ -0,0 +1,85 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.array import GenerateInstanceBorder
+from tests.utils import TEST_NDARRAYS
+
+EXCEPTION_TESTS = []
+TESTS = []
+
+np.random.RandomState(123)
+
+for p in TEST_NDARRAYS:
+ EXCEPTION_TESTS.append(
+ [
+ {"kernel_size": 3, "remove_small_objects": False},
+ p(np.random.rand(1, 5, 5, 5)),
+ p(np.random.rand(2, 5, 5)),
+ ValueError,
+ ]
+ )
+
+ EXCEPTION_TESTS.append(
+ [
+ {"kernel_size": 3, "remove_small_objects": False},
+ p(np.random.rand(1, 5, 5)),
+ p(np.random.rand(1, 5, 5)),
+ ValueError,
+ ]
+ )
+
+ EXCEPTION_TESTS.append(
+ [
+ {"kernel_size": 3, "remove_small_objects": False},
+ p(np.random.rand(2, 5, 5)),
+ p(np.random.rand(2, 5, 5)),
+ ValueError,
+ ]
+ )
+
+for p in TEST_NDARRAYS:
+ TESTS.append(
+ [
+ {"kernel_size": 3, "remove_small_objects": False},
+ p(np.random.rand(1, 5, 5)),
+ p(np.random.rand(2, 5, 5)),
+ (1, 5, 5),
+ ]
+ )
+ TESTS.append(
+ [
+ {"kernel_size": 3, "remove_small_objects": False},
+ p(np.random.rand(1, 5, 5)),
+ p(np.random.rand(2, 5, 5)),
+ (1, 5, 5),
+ ]
+ )
+
+
+class TestGenerateInstanceBorder(unittest.TestCase):
+ @parameterized.expand(EXCEPTION_TESTS)
+ def test_value(self, argments, mask, hover_map, exception_type):
+ with self.assertRaises(exception_type):
+ GenerateInstanceBorder(**argments)(mask, hover_map)
+
+ @parameterized.expand(TESTS)
+ def test_value2(self, argments, mask, hover_map, expected_shape):
+ result = GenerateInstanceBorder(**argments)(mask, hover_map)
+ self.assertEqual(result.shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_instance_borderd.py b/tests/test_generate_instance_borderd.py
new file mode 100644
index 00000000000..a4ee5221a64
--- /dev/null
+++ b/tests/test_generate_instance_borderd.py
@@ -0,0 +1,85 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.dictionary import GenerateInstanceBorderd
+from tests.utils import TEST_NDARRAYS
+
+EXCEPTION_TESTS = []
+TESTS = []
+
+np.random.RandomState(123)
+
+for p in TEST_NDARRAYS:
+ EXCEPTION_TESTS.append(
+ [
+ {"keys": "mask", "kernel_size": 3, "remove_small_objects": True, "min_size": 10},
+ p(np.random.rand(1, 5, 5, 5)),
+ p(np.random.rand(2, 5, 5)),
+ ValueError,
+ ]
+ )
+
+ EXCEPTION_TESTS.append(
+ [
+ {"keys": "mask", "kernel_size": 3, "remove_small_objects": True, "min_size": 10},
+ p(np.random.rand(1, 5, 5)),
+ p(np.random.rand(1, 5, 5)),
+ ValueError,
+ ]
+ )
+
+ EXCEPTION_TESTS.append(
+ [
+ {"keys": "mask", "kernel_size": 3, "remove_small_objects": True, "min_size": 10},
+ p(np.random.rand(2, 5, 5)),
+ p(np.random.rand(2, 5, 5)),
+ ValueError,
+ ]
+ )
+
+for p in TEST_NDARRAYS:
+ TESTS.append(
+ [
+ {"keys": "mask", "kernel_size": 3, "remove_small_objects": False, "min_size": 10},
+ p(np.random.rand(1, 5, 5)),
+ p(np.random.rand(2, 5, 5)),
+ (1, 5, 5),
+ ]
+ )
+ TESTS.append(
+ [
+ {"keys": "mask", "kernel_size": 3, "remove_small_objects": True, "min_size": 10},
+ p(np.random.rand(1, 5, 5)),
+ p(np.random.rand(2, 5, 5)),
+ (1, 5, 5),
+ ]
+ )
+
+
+class TestGenerateInstanceBorderd(unittest.TestCase):
+ @parameterized.expand(EXCEPTION_TESTS)
+ def test_value(self, argments, mask, hover_map, exception_type):
+ with self.assertRaises(exception_type):
+ GenerateInstanceBorderd(**argments)({"mask": mask, "hover_map": hover_map})
+
+ @parameterized.expand(TESTS)
+ def test_value2(self, argments, mask, hover_map, expected_shape):
+ result = GenerateInstanceBorderd(**argments)({"mask": mask, "hover_map": hover_map})
+ self.assertEqual(result["border"].shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_instance_centroid.py b/tests/test_generate_instance_centroid.py
new file mode 100644
index 00000000000..46f94be6371
--- /dev/null
+++ b/tests/test_generate_instance_centroid.py
@@ -0,0 +1,52 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.array import GenerateInstanceCentroid
+from monai.transforms import BoundingRect
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+_, has_skimage = optional_import("skimage", "0.19.3", min_version)
+
+y, x = np.ogrid[0:30, 0:30]
+get_bbox = BoundingRect()
+
+TEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, [0, 0], [2, 2]]
+
+TEST_CASE_2 = [(x - 8) ** 2 + (y - 8) ** 2 <= 2**2, [6, 6], [8, 8]]
+
+TEST_CASE_3 = [(x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1, [2, 3], [4, 6]]
+
+
+TEST_CASE = []
+for p in TEST_NDARRAYS:
+ TEST_CASE.append([p, *TEST_CASE_1])
+ TEST_CASE.append([p, *TEST_CASE_2])
+ TEST_CASE.append([p, *TEST_CASE_3])
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+class TestGenerateInstanceCentroid(unittest.TestCase):
+ @parameterized.expand(TEST_CASE)
+ def test_shape(self, in_type, test_data, offset, expected):
+ inst_bbox = get_bbox(test_data[None])
+ inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]
+ result = GenerateInstanceCentroid()(in_type(inst_map[None]), offset=offset)
+ assert_allclose(result, expected, type_test=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_instance_centroidd.py b/tests/test_generate_instance_centroidd.py
new file mode 100644
index 00000000000..f989de5ff29
--- /dev/null
+++ b/tests/test_generate_instance_centroidd.py
@@ -0,0 +1,54 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.dictionary import GenerateInstanceCentroidd
+from monai.transforms import BoundingRect
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+_, has_skimage = optional_import("skimage", "0.19.3", min_version)
+
+y, x = np.ogrid[0:30, 0:30]
+get_bbox = BoundingRect()
+
+TEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, [0, 0], [2, 2]]
+
+TEST_CASE_2 = [(x - 8) ** 2 + (y - 8) ** 2 <= 2**2, [6, 6], [8, 8]]
+
+TEST_CASE_3 = [(x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1, [2, 3], [4, 6]]
+
+TEST_CASE = []
+for p in TEST_NDARRAYS:
+ TEST_CASE.append([p, *TEST_CASE_1])
+ TEST_CASE.append([p, *TEST_CASE_2])
+ TEST_CASE.append([p, *TEST_CASE_3])
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+class TestGenerateInstanceCentroidd(unittest.TestCase):
+ @parameterized.expand(TEST_CASE)
+ def test_shape(self, in_type, test_data, offset, expected):
+ inst_bbox = get_bbox(test_data[None])
+ inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]
+ test_case = {"image": in_type(inst_map[None]), "offset": offset}
+ result = GenerateInstanceCentroidd(keys="image", centroid_key_postfix="centroid", offset_key="offset")(
+ test_case
+ )
+ assert_allclose(result["image_centroid"], expected, type_test=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_instance_contour.py b/tests/test_generate_instance_contour.py
new file mode 100644
index 00000000000..22b778c06c6
--- /dev/null
+++ b/tests/test_generate_instance_contour.py
@@ -0,0 +1,57 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.array import GenerateInstanceContour
+from monai.transforms import BoundingRect
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+_, has_skimage = optional_import("skimage", "0.19.3", min_version)
+
+y, x = np.ogrid[0:30, 0:30]
+get_bbox = BoundingRect()
+
+TEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, 3, [0, 0], [[2, 0], [0, 2], [2, 4], [4, 2]]]
+
+TEST_CASE_2 = [(x - 8) ** 2 + (y - 8) ** 2 <= 2**2, 3, [8, 8], [[10, 8], [8, 10], [10, 12], [12, 10]]]
+
+TEST_CASE_3 = [
+ (x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1,
+ 3,
+ [2, 3],
+ [[5, 3], [4, 4], [3, 4], [2, 5], [3, 6], [4, 6], [5, 7], [6, 6], [7, 6], [8, 5], [7, 4], [6, 4]],
+]
+
+TEST_CASE = []
+for p in TEST_NDARRAYS:
+ TEST_CASE.append([p, *TEST_CASE_1])
+ TEST_CASE.append([p, *TEST_CASE_2])
+ TEST_CASE.append([p, *TEST_CASE_3])
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+class TestGenerateInstanceContour(unittest.TestCase):
+ @parameterized.expand(TEST_CASE)
+ def test_shape(self, in_type, test_data, points_num, offset, expected):
+
+ inst_bbox = get_bbox(test_data[None])
+ inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]
+ result = GenerateInstanceContour(points_num=points_num)(in_type(inst_map[None]), offset=offset)
+ assert_allclose(result, expected, type_test=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_instance_contourd.py b/tests/test_generate_instance_contourd.py
new file mode 100644
index 00000000000..9c9c1efbe68
--- /dev/null
+++ b/tests/test_generate_instance_contourd.py
@@ -0,0 +1,60 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.dictionary import GenerateInstanceContourd
+from monai.transforms import BoundingRect
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+_, has_skimage = optional_import("skimage", "0.19.3", min_version)
+
+y, x = np.ogrid[0:30, 0:30]
+get_bbox = BoundingRect()
+
+TEST_CASE_1 = [(x - 2) ** 2 + (y - 2) ** 2 <= 2**2, 3, [0, 0], [[2, 0], [0, 2], [2, 4], [4, 2]]]
+
+TEST_CASE_2 = [(x - 10) ** 2 + (y - 10) ** 2 <= 2**2, 3, [8, 8], [[10, 8], [8, 10], [10, 12], [12, 10]]]
+
+
+TEST_CASE_3 = [
+ (x - 5) ** 2 / 3**2 + (y - 5) ** 2 / 2**2 <= 1,
+ 3,
+ [2, 3],
+ [[5, 3], [4, 4], [3, 4], [2, 5], [3, 6], [4, 6], [5, 7], [6, 6], [7, 6], [8, 5], [7, 4], [6, 4]],
+]
+
+TEST_CASE = []
+for p in TEST_NDARRAYS:
+ TEST_CASE.append([p, *TEST_CASE_1])
+ TEST_CASE.append([p, *TEST_CASE_2])
+ TEST_CASE.append([p, *TEST_CASE_3])
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+class TestGenerateInstanceContourd(unittest.TestCase):
+ @parameterized.expand(TEST_CASE)
+ def test_shape(self, in_type, test_data, points_num, offset, expected):
+ inst_bbox = get_bbox(test_data[None])
+ inst_map = test_data[inst_bbox[0][0] : inst_bbox[0][1], inst_bbox[0][2] : inst_bbox[0][3]]
+ test_data = {"image": in_type(inst_map[None]), "offset": offset}
+ result = GenerateInstanceContourd(
+ keys="image", contour_key_postfix="contour", offset_key="offset", points_num=points_num
+ )(test_data)
+ assert_allclose(result["image_contour"], expected, type_test=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_instance_type.py b/tests/test_generate_instance_type.py
new file mode 100644
index 00000000000..8a083d19b76
--- /dev/null
+++ b/tests/test_generate_instance_type.py
@@ -0,0 +1,49 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.array import GenerateInstanceType
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+y, x = np.ogrid[0:30, 0:30]
+
+TEST_CASE_1 = [
+ (x - 2) ** 2 + (y - 2) ** 2 <= 2**2,
+ (x - 2) ** 2 + (y - 3) ** 2 <= 2**2,
+ np.array([[0, 5, 0, 5]]),
+ [1, 0.6666666111111158],
+]
+
+TEST_CASE_2 = [
+ (x - 8) ** 2 / 3**2 + (y - 8) ** 2 / 2**2 <= 1,
+ (x - 7) ** 2 / 3**2 + (y - 7) ** 2 / 2**2 <= 1,
+ np.array([[6, 11, 5, 12]]),
+ [1, 0.7058823114186875],
+]
+TEST_CASE = []
+for p in TEST_NDARRAYS:
+ TEST_CASE.append([p, *TEST_CASE_1])
+ TEST_CASE.append([p, *TEST_CASE_2])
+
+
+class TestGenerateInstanceType(unittest.TestCase):
+ @parameterized.expand(TEST_CASE)
+ def test_shape(self, in_type, type_pred, seg_pred, bbox, expected):
+ result = GenerateInstanceType()(in_type(type_pred[None]), in_type(seg_pred[None]), bbox, 1)
+ assert_allclose(result, expected)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_instance_typed.py b/tests/test_generate_instance_typed.py
new file mode 100644
index 00000000000..08d9f550a9d
--- /dev/null
+++ b/tests/test_generate_instance_typed.py
@@ -0,0 +1,51 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.dictionary import GenerateInstanceTyped
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+y, x = np.ogrid[0:30, 0:30]
+
+TEST_CASE_1 = [
+ (x - 2) ** 2 + (y - 2) ** 2 <= 2**2,
+ (x - 2) ** 2 + (y - 3) ** 2 <= 2**2,
+ np.array([[0, 5, 0, 5]]),
+ [1, 0.6666666111111158],
+]
+
+TEST_CASE_2 = [
+ (x - 8) ** 2 / 3**2 + (y - 8) ** 2 / 2**2 <= 1,
+ (x - 7) ** 2 / 3**2 + (y - 7) ** 2 / 2**2 <= 1,
+ np.array([[6, 11, 5, 12]]),
+ [1, 0.7058823114186875],
+]
+TEST_CASE = []
+for p in TEST_NDARRAYS:
+ TEST_CASE.append([p, *TEST_CASE_1])
+ TEST_CASE.append([p, *TEST_CASE_2])
+
+
+class TestGenerateInstanceTyped(unittest.TestCase):
+ @parameterized.expand(TEST_CASE)
+ def test_shape(self, in_type, type_pred, seg_pred, bbox, expected):
+ test_data = {"type_pred": in_type(type_pred[None]), "seg": in_type(seg_pred[None]), "bbox": bbox, "id": 1}
+ result = GenerateInstanceTyped(keys="type_pred")(test_data)
+ assert_allclose(result["type_info"]["inst_type"], expected[0])
+ assert_allclose(result["type_info"]["type_prob"], expected[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_param_groups.py b/tests/test_generate_param_groups.py
index 0b259442ea3..7ae42b8ec68 100644
--- a/tests/test_generate_param_groups.py
+++ b/tests/test_generate_param_groups.py
@@ -17,6 +17,7 @@
from monai.networks.nets import Unet
from monai.optimizers import generate_param_groups
from monai.utils import ensure_tuple
+from tests.utils import assert_allclose
TEST_CASE_1 = [{"layer_matches": [lambda x: x.model[-1]], "match_types": "select", "lr_values": [1]}, (1, 100), [5, 21]]
@@ -76,7 +77,7 @@ def test_lr_values(self, input_param, expected_values, expected_groups):
optimizer = torch.optim.Adam(params, 100)
for param_group, value in zip(optimizer.param_groups, ensure_tuple(expected_values)):
- torch.testing.assert_allclose(param_group["lr"], value)
+ assert_allclose(param_group["lr"], value)
n = [len(p["params"]) for p in params]
self.assertListEqual(n, expected_groups)
diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py
index 91db0e9d960..d1a208770f4 100644
--- a/tests/test_generate_pos_neg_label_crop_centers.py
+++ b/tests/test_generate_pos_neg_label_crop_centers.py
@@ -31,7 +31,20 @@
list,
2,
3,
- ]
+ ],
+ [
+ {
+ "spatial_size": [2, 2, 2],
+ "num_samples": 2,
+ "pos_ratio": 0.0,
+ "label_spatial_shape": [3, 3, 3],
+ "fg_indices": [],
+ "bg_indices": [3, 12, 21],
+ },
+ list,
+ 2,
+ 3,
+ ],
]
diff --git a/tests/test_generate_succinct_contour.py b/tests/test_generate_succinct_contour.py
new file mode 100644
index 00000000000..478c23b5229
--- /dev/null
+++ b/tests/test_generate_succinct_contour.py
@@ -0,0 +1,52 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.array import GenerateSuccinctContour
+
+TEST_CASE_1 = [
+ [
+ np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.0, 1.5]]),
+ np.array([[0.0, 2.5], [0.5, 3.0], [1.0, 3.5], [1.5, 4.0]]),
+ np.array([[4.0, 1.5], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]),
+ np.array([[2.5, 4.0], [3.0, 3.5], [3.5, 3.0], [4.0, 2.5]]),
+ ],
+ 5,
+ 5,
+ [[2, 0], [0, 2], [2, 4], [4, 2]],
+]
+
+TEST_CASE_2 = [
+ [
+ np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.5, 2.0], [0.0, 2.5]]),
+ np.array([[0.0, 3.5], [0.5, 4.0], [0.5, 5.0], [1.0, 5.5], [1.5, 6.0]]),
+ np.array([[4.0, 2.5], [3.5, 2.0], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]),
+ np.array([[2.5, 6.0], [3.0, 5.5], [3.5, 5.0], [3.5, 4.0], [4.0, 3.5]]),
+ ],
+ 5,
+ 7,
+ [[3, 0], [2, 1], [1, 1], [0, 2], [1, 3], [2, 3], [3, 4], [4, 3], [5, 3], [6, 2], [5, 1], [4, 1]],
+]
+
+
+class TestGenerateSuccinctContour(unittest.TestCase):
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+ def test_shape(self, test_data, height, width, expected):
+ result = GenerateSuccinctContour(height=height, width=width)(test_data)
+ np.testing.assert_allclose(result, expected)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_succinct_contourd.py b/tests/test_generate_succinct_contourd.py
new file mode 100644
index 00000000000..b34142ec0d0
--- /dev/null
+++ b/tests/test_generate_succinct_contourd.py
@@ -0,0 +1,54 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.dictionary import GenerateSuccinctContourd
+
+y, x = np.ogrid[0:5, 0:5]
+TEST_CASE_1 = [
+ [
+ np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.0, 1.5]]),
+ np.array([[0.0, 2.5], [0.5, 3.0], [1.0, 3.5], [1.5, 4.0]]),
+ np.array([[4.0, 1.5], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]),
+ np.array([[2.5, 4.0], [3.0, 3.5], [3.5, 3.0], [4.0, 2.5]]),
+ ],
+ 5,
+ 5,
+ [[2, 0], [0, 2], [2, 4], [4, 2]],
+]
+
+TEST_CASE_2 = [
+ [
+ np.array([[1.5, 0.0], [1.0, 0.5], [0.5, 1.0], [0.5, 2.0], [0.0, 2.5]]),
+ np.array([[0.0, 3.5], [0.5, 4.0], [0.5, 5.0], [1.0, 5.5], [1.5, 6.0]]),
+ np.array([[4.0, 2.5], [3.5, 2.0], [3.5, 1.0], [3.0, 0.5], [2.5, 0.0]]),
+ np.array([[2.5, 6.0], [3.0, 5.5], [3.5, 5.0], [3.5, 4.0], [4.0, 3.5]]),
+ ],
+ 5,
+ 7,
+ [[3, 0], [2, 1], [1, 1], [0, 2], [1, 3], [2, 3], [3, 4], [4, 3], [5, 3], [6, 2], [5, 1], [4, 1]],
+]
+
+
+class TestGenerateSuccinctContour(unittest.TestCase):
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+ def test_shape(self, data, height, width, expected):
+ test_data = {"contour": data}
+ result = GenerateSuccinctContourd(keys="contour", height=height, width=width)(test_data)
+ np.testing.assert_allclose(result["contour"], expected)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_watershed_markers.py b/tests/test_generate_watershed_markers.py
new file mode 100644
index 00000000000..7b046686e95
--- /dev/null
+++ b/tests/test_generate_watershed_markers.py
@@ -0,0 +1,53 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.array import GenerateWatershedMarkers
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS
+
+_, has_skimage = optional_import("skimage", "0.19.3", min_version)
+_, has_scipy = optional_import("scipy", "1.8.1", min_version)
+
+EXCEPTION_TESTS = []
+TESTS = []
+
+np.random.RandomState(123)
+
+for p in TEST_NDARRAYS:
+ EXCEPTION_TESTS.append([{}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError])
+
+ EXCEPTION_TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError])
+
+for p in TEST_NDARRAYS:
+ TESTS.append([{}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)])
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+@unittest.skipUnless(has_scipy, "Requires scipy library.")
+class TestGenerateWatershedMarkers(unittest.TestCase):
+ @parameterized.expand(EXCEPTION_TESTS)
+ def test_value(self, argments, mask, probmap, exception_type):
+ with self.assertRaises(exception_type):
+ GenerateWatershedMarkers(**argments)(mask, probmap)
+
+ @parameterized.expand(TESTS)
+ def test_value2(self, argments, mask, probmap, expected_shape):
+ result = GenerateWatershedMarkers(**argments)(mask, probmap)
+ self.assertEqual(result.shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_watershed_markersd.py b/tests/test_generate_watershed_markersd.py
new file mode 100644
index 00000000000..cccb20c985b
--- /dev/null
+++ b/tests/test_generate_watershed_markersd.py
@@ -0,0 +1,68 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.dictionary import GenerateWatershedMarkersd
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS
+
+_, has_skimage = optional_import("skimage", "0.19.3", min_version)
+_, has_scipy = optional_import("scipy", "1.8.1", min_version)
+
+EXCEPTION_TESTS = []
+TESTS = []
+
+np.random.RandomState(123)
+
+for p in TEST_NDARRAYS:
+ EXCEPTION_TESTS.append(
+ [{"keys": "mask", "border_key": "border"}, p(np.random.rand(2, 5, 5)), p(np.random.rand(1, 5, 5)), ValueError]
+ )
+
+ EXCEPTION_TESTS.append(
+ [{"keys": "mask", "border_key": "border"}, p(np.random.rand(1, 5, 5)), p(np.random.rand(2, 5, 5)), ValueError]
+ )
+
+ EXCEPTION_TESTS.append(
+ [
+ {"keys": "mask", "border_key": "border", "markers_key": "old_markers"},
+ p(np.random.rand(1, 5, 5)),
+ p(np.random.rand(1, 5, 5)),
+ KeyError,
+ ]
+ )
+
+for p in TEST_NDARRAYS:
+ TESTS.append(
+ [{"keys": "mask", "border_key": "border"}, p(np.random.rand(1, 5, 5)), p(np.random.rand(1, 5, 5)), (1, 5, 5)]
+ )
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+@unittest.skipUnless(has_scipy, "Requires scipy library.")
+class TestGenerateWatershedMarkersd(unittest.TestCase):
+ @parameterized.expand(EXCEPTION_TESTS)
+ def test_value(self, argments, mask, border_map, exception_type):
+ with self.assertRaises(exception_type):
+ GenerateWatershedMarkersd(**argments)({"mask": mask, "border": border_map, "old_markers": 1})
+
+ @parameterized.expand(TESTS)
+ def test_value2(self, argments, mask, border_map, expected_shape):
+ result = GenerateWatershedMarkersd(**argments)({"mask": mask, "border": border_map})
+ self.assertEqual(result["markers"].shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_watershed_mask.py b/tests/test_generate_watershed_mask.py
new file mode 100644
index 00000000000..1e2d84c7d28
--- /dev/null
+++ b/tests/test_generate_watershed_mask.py
@@ -0,0 +1,81 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.array import GenerateWatershedMask
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS
+
+_, has_scipy = optional_import("scipy", "1.8.1", min_version)
+
+EXCEPTION_TESTS = []
+TESTS = []
+
+np.random.RandomState(123)
+
+for p in TEST_NDARRAYS:
+ EXCEPTION_TESTS.append(
+ [
+ {"softmax": False, "sigmoid": True, "remove_small_objects": True, "min_size": 10},
+ p(np.random.rand(1, 5, 5)),
+ ValueError,
+ ]
+ )
+
+for p in TEST_NDARRAYS:
+ TESTS.append(
+ [
+ {"softmax": True, "sigmoid": False, "threshold": None, "remove_small_objects": False, "min_size": 10},
+ p(
+ [
+ [[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [0.6134, 0.6389, 0.0680]],
+ [[0.5000, 0.3400, 0.9900], [0.8900, 0.5600, 0.2700], [0.6100, 0.6300, 0.0600]],
+ ]
+ ),
+ (1, 3, 3),
+ [0, 1],
+ ]
+ )
+
+ TESTS.append(
+ [
+ {"softmax": False, "sigmoid": True, "threshold": 0.5, "remove_small_objects": False, "min_size": 10},
+ p([[[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [-0.1134, -0.0389, -0.0680]]]),
+ (1, 3, 3),
+ [0, 1],
+ ]
+ )
+
+
+@unittest.skipUnless(has_scipy, "Requires scipy library.")
+class TestGenerateWatershedMask(unittest.TestCase):
+ @parameterized.expand(EXCEPTION_TESTS)
+ def test_value(self, argments, image, exception_type):
+ with self.assertRaises(exception_type):
+ GenerateWatershedMask(**argments)(image)
+
+ @parameterized.expand(TESTS)
+ def test_value2(self, argments, image, expected_shape, expected_value):
+ result = GenerateWatershedMask(**argments)(image)
+ self.assertEqual(result.shape, expected_shape)
+
+ if isinstance(result, torch.Tensor):
+ result = result.cpu().numpy()
+ self.assertEqual(np.unique(result).tolist(), expected_value)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_generate_watershed_maskd.py b/tests/test_generate_watershed_maskd.py
new file mode 100644
index 00000000000..4e3a2ee15cb
--- /dev/null
+++ b/tests/test_generate_watershed_maskd.py
@@ -0,0 +1,97 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.dictionary import GenerateWatershedMaskd
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS
+
+_, has_scipy = optional_import("scipy", "1.8.1", min_version)
+
+EXCEPTION_TESTS = []
+TESTS = []
+
+np.random.RandomState(123)
+
+for p in TEST_NDARRAYS:
+ EXCEPTION_TESTS.append(
+ [
+ {"keys": "img", "softmax": False, "sigmoid": True, "remove_small_objects": True, "min_size": 10},
+ p(np.random.rand(1, 5, 5)),
+ ValueError,
+ ]
+ )
+
+for p in TEST_NDARRAYS:
+ TESTS.append(
+ [
+ {
+ "keys": "img",
+ "mask_key": "mask",
+ "softmax": True,
+ "sigmoid": False,
+ "threshold": None,
+ "remove_small_objects": False,
+ "min_size": 10,
+ },
+ p(
+ [
+ [[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [0.6134, 0.6389, 0.0680]],
+ [[0.5000, 0.3400, 0.9900], [0.8900, 0.5600, 0.2700], [0.6100, 0.6300, 0.0600]],
+ ]
+ ),
+ (1, 3, 3),
+ [0, 1],
+ ]
+ )
+
+ TESTS.append(
+ [
+ {
+ "keys": "img",
+ "mask_key": "mask",
+ "softmax": False,
+ "sigmoid": True,
+ "threshold": 0.5,
+ "remove_small_objects": False,
+ "min_size": 10,
+ },
+ p([[[0.5022, 0.3403, 0.9997], [0.8793, 0.5514, 0.2697], [-0.1134, -0.0389, -0.0680]]]),
+ (1, 3, 3),
+ [0, 1],
+ ]
+ )
+
+
+@unittest.skipUnless(has_scipy, "Requires scipy library.")
+class TestGenerateWatershedMaskd(unittest.TestCase):
+ @parameterized.expand(EXCEPTION_TESTS)
+ def test_value(self, argments, image, exception_type):
+ with self.assertRaises(exception_type):
+ GenerateWatershedMaskd(**argments)({"img": image})
+
+ @parameterized.expand(TESTS)
+ def test_value2(self, argments, image, expected_shape, expected_value):
+ result = GenerateWatershedMaskd(**argments)({"img": image})
+ self.assertEqual(result["mask"].shape, expected_shape)
+
+ if isinstance(result["mask"], torch.Tensor):
+ result["mask"] = result["mask"].cpu().numpy()
+ self.assertEqual(np.unique(result["mask"]).tolist(), expected_value)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_get_unique_labels.py b/tests/test_get_unique_labels.py
index 9bc6f9b152f..67953a32059 100644
--- a/tests/test_get_unique_labels.py
+++ b/tests/test_get_unique_labels.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import torch
diff --git a/tests/test_globalnet.py b/tests/test_globalnet.py
index ef0209e397e..4a3e9c124c8 100644
--- a/tests/test_globalnet.py
+++ b/tests/test_globalnet.py
@@ -40,7 +40,6 @@
],
]
-
TEST_CASES_GLOBAL_NET = [
[
{
diff --git a/tests/test_gmm.py b/tests/test_gmm.py
index f085dd916c7..ad5e383a6aa 100644
--- a/tests/test_gmm.py
+++ b/tests/test_gmm.py
@@ -18,8 +18,9 @@
import torch
from parameterized import parameterized
+from monai._extensions import load_module
from monai.networks.layers import GaussianMixtureModel
-from tests.utils import skip_if_no_cuda
+from tests.utils import skip_if_darwin, skip_if_no_cuda, skip_if_windows
TEST_CASES = [
[
@@ -256,10 +257,9 @@
]
-@skip_if_no_cuda
class GMMTestCase(unittest.TestCase):
def setUp(self):
- self._var = os.environ.get("TORCH_EXTENSIONS_DIR", None)
+ self._var = os.environ.get("TORCH_EXTENSIONS_DIR")
self.tempdir = tempfile.mkdtemp()
os.environ["TORCH_EXTENSIONS_DIR"] = self.tempdir
@@ -271,6 +271,7 @@ def tearDown(self) -> None:
shutil.rmtree(self.tempdir)
@parameterized.expand(TEST_CASES)
+ @skip_if_no_cuda
def test_cuda(self, test_case_description, mixture_count, class_count, features, labels, expected):
# Device to run on
@@ -297,6 +298,15 @@ def test_cuda(self, test_case_description, mixture_count, class_count, features,
# Ensure result are as expected
np.testing.assert_allclose(results, expected, atol=1e-3)
+ @skip_if_darwin
+ @skip_if_windows
+ def test_load(self):
+ if not torch.cuda.is_available():
+ with self.assertRaisesRegex(ImportError, ".*symbol.*"): # expecting import error if no cuda
+ load_module("gmm", {"CHANNEL_COUNT": 2, "MIXTURE_COUNT": 2, "MIXTURE_SIZE": 3}, verbose_build=True)
+ else:
+ load_module("gmm", {"CHANNEL_COUNT": 2, "MIXTURE_COUNT": 2, "MIXTURE_SIZE": 3}, verbose_build=True)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py
index 30680c8e314..4c81035210b 100644
--- a/tests/test_grid_dataset.py
+++ b/tests/test_grid_dataset.py
@@ -60,7 +60,7 @@ def test_loading_array(self):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
- np.array([[[[8.0577, 9.0577], [12.0577, 13.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]),
+ np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
rtol=1e-4,
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -69,7 +69,9 @@ def test_loading_array(self):
np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2))
np.testing.assert_allclose(
item[0],
- np.array([[[[7.6533, 8.6533], [11.6533, 12.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]),
+ np.array(
+ [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
+ ),
rtol=1e-3,
)
np.testing.assert_allclose(
@@ -102,7 +104,7 @@ def test_loading_dict(self):
self.assertListEqual(item[0]["metadata"], ["test string", "test string"])
np.testing.assert_allclose(
item[0]["image"],
- np.array([[[[8.0577, 9.0577], [12.0577, 13.0577]]], [[[10.5540, 11.5540], [14.5540, 15.5540]]]]),
+ np.array([[[[8.240326, 9.240326], [12.240326, 13.240326]]], [[[10.1624, 11.1624], [14.1624, 15.1624]]]]),
rtol=1e-4,
)
np.testing.assert_allclose(item[1], np.array([[[0, 1], [2, 4], [0, 2]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5)
@@ -111,7 +113,9 @@ def test_loading_dict(self):
np.testing.assert_equal(item[0]["image"].shape, (2, 1, 2, 2))
np.testing.assert_allclose(
item[0]["image"],
- np.array([[[[7.6533, 8.6533], [11.6533, 12.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]),
+ np.array(
+ [[[[7.723618, 8.723618], [11.723618, 12.723618]]], [[[10.7175, 11.7175], [14.7175, 15.7175]]]]
+ ),
rtol=1e-3,
)
np.testing.assert_allclose(
diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py
index 8a105afcd27..03b33147dd8 100644
--- a/tests/test_grid_patch.py
+++ b/tests/test_grid_patch.py
@@ -14,8 +14,9 @@
import numpy as np
from parameterized import parameterized
+from monai.data import MetaTensor, set_track_meta
from monai.transforms.spatial.array import GridPatch
-from tests.utils import TEST_NDARRAYS, assert_allclose
+from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose
A = np.arange(16).repeat(3).reshape(4, 4, 3).transpose(2, 0, 1)
A11 = A[:, :2, :2]
@@ -46,34 +47,69 @@
]
TEST_CASE_13 = [{"patch_size": (2, 2), "threshold": 50.0}, A, [A11]]
+TEST_CASE_MEAT_0 = [
+ {"patch_size": (2, 2)},
+ A,
+ [A11, A12, A21, A22],
+ [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}],
+]
+
+TEST_CASE_MEAT_1 = [
+ {"patch_size": (2, 2)},
+ MetaTensor(x=A, meta={"path": "path/to/file"}),
+ [A11, A12, A21, A22],
+ [
+ {"location": [0, 0], "path": "path/to/file"},
+ {"location": [0, 2], "path": "path/to/file"},
+ {"location": [2, 0], "path": "path/to/file"},
+ {"location": [2, 2], "path": "path/to/file"},
+ ],
+]
-TEST_SINGLE = []
+TEST_CASES = []
for p in TEST_NDARRAYS:
- TEST_SINGLE.append([p, *TEST_CASE_0])
- TEST_SINGLE.append([p, *TEST_CASE_1])
- TEST_SINGLE.append([p, *TEST_CASE_2])
- TEST_SINGLE.append([p, *TEST_CASE_3])
- TEST_SINGLE.append([p, *TEST_CASE_4])
- TEST_SINGLE.append([p, *TEST_CASE_5])
- TEST_SINGLE.append([p, *TEST_CASE_6])
- TEST_SINGLE.append([p, *TEST_CASE_7])
- TEST_SINGLE.append([p, *TEST_CASE_8])
- TEST_SINGLE.append([p, *TEST_CASE_9])
- TEST_SINGLE.append([p, *TEST_CASE_10])
- TEST_SINGLE.append([p, *TEST_CASE_11])
- TEST_SINGLE.append([p, *TEST_CASE_12])
- TEST_SINGLE.append([p, *TEST_CASE_13])
+ TEST_CASES.append([p, *TEST_CASE_0])
+ TEST_CASES.append([p, *TEST_CASE_1])
+ TEST_CASES.append([p, *TEST_CASE_2])
+ TEST_CASES.append([p, *TEST_CASE_3])
+ TEST_CASES.append([p, *TEST_CASE_4])
+ TEST_CASES.append([p, *TEST_CASE_5])
+ TEST_CASES.append([p, *TEST_CASE_6])
+ TEST_CASES.append([p, *TEST_CASE_7])
+ TEST_CASES.append([p, *TEST_CASE_8])
+ TEST_CASES.append([p, *TEST_CASE_9])
+ TEST_CASES.append([p, *TEST_CASE_10])
+ TEST_CASES.append([p, *TEST_CASE_11])
+ TEST_CASES.append([p, *TEST_CASE_12])
+ TEST_CASES.append([p, *TEST_CASE_13])
class TestGridPatch(unittest.TestCase):
- @parameterized.expand(TEST_SINGLE)
+ @parameterized.expand(TEST_CASES)
def test_grid_patch(self, in_type, input_parameters, image, expected):
input_image = in_type(image)
splitter = GridPatch(**input_parameters)
- output = list(splitter(input_image))
+ output = splitter(input_image)
self.assertEqual(len(output), len(expected))
for output_patch, expected_patch in zip(output, expected):
- assert_allclose(output_patch[0], expected_patch, type_test=False)
+ assert_allclose(output_patch, expected_patch, type_test=False)
+
+ @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1])
+ @SkipIfBeforePyTorchVersion((1, 9, 1))
+ def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta):
+ set_track_meta(True)
+ splitter = GridPatch(**input_parameters)
+ output = splitter(image)
+ self.assertEqual(len(output), len(expected))
+ if "path" in expected_meta[0]:
+ self.assertTrue(output.meta["path"] == expected_meta[0]["path"])
+ for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta):
+ assert_allclose(output_patch, expected_patch, type_test=False)
+ self.assertTrue(isinstance(output_patch, MetaTensor))
+ self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"])
+ self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:]))
+ if "path" in expected_meta[0]:
+ self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"])
if __name__ == "__main__":
diff --git a/tests/test_grid_patchd.py b/tests/test_grid_patchd.py
index 8f1e238b425..0f1bea5f8a9 100644
--- a/tests/test_grid_patchd.py
+++ b/tests/test_grid_patchd.py
@@ -74,10 +74,10 @@ def test_grid_patchd(self, in_type, input_parameters, image_dict, expected):
if k == image_key:
input_dict[k] = in_type(v)
splitter = GridPatchd(keys=image_key, **input_parameters)
- output = list(splitter(input_dict))
- self.assertEqual(len(output), len(expected))
- for output_patch, expected_patch in zip(output, expected):
- assert_allclose(output_patch[image_key], expected_patch, type_test=False)
+ output = splitter(input_dict)
+ self.assertEqual(len(output[image_key]), len(expected))
+ for output_patch, expected_patch in zip(output[image_key], expected):
+ assert_allclose(output_patch, expected_patch, type_test=False)
if __name__ == "__main__":
diff --git a/tests/test_handler_checkpoint_loader.py b/tests/test_handler_checkpoint_loader.py
index 23316022344..bdd44996876 100644
--- a/tests/test_handler_checkpoint_loader.py
+++ b/tests/test_handler_checkpoint_loader.py
@@ -17,6 +17,7 @@
from ignite.engine import Engine, Events
from monai.handlers import CheckpointLoader, CheckpointSaver
+from tests.utils import assert_allclose
class TestHandlerCheckpointLoader(unittest.TestCase):
@@ -42,7 +43,7 @@ def check_epoch(engine: Engine):
self.assertEqual(engine.state.epoch, 5)
engine2.run([0] * 8, max_epochs=8)
- torch.testing.assert_allclose(net2.state_dict()["weight"], torch.tensor([0.1]))
+ assert_allclose(net2.state_dict()["weight"], torch.tensor([0.1]))
# test bad case with max_epochs smaller than current epoch
engine3 = Engine(lambda e, b: None)
@@ -73,7 +74,7 @@ def test_two_save_one_load(self):
engine = Engine(lambda e, b: None)
CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=True).attach(engine)
engine.run([0] * 8, max_epochs=1)
- torch.testing.assert_allclose(net2.state_dict()["weight"], torch.tensor([0.1]))
+ assert_allclose(net2.state_dict()["weight"], torch.tensor([0.1]))
def test_save_single_device_load_multi_devices(self):
net1 = torch.nn.PReLU()
@@ -93,7 +94,7 @@ def test_save_single_device_load_multi_devices(self):
engine = Engine(lambda e, b: None)
CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=True).attach(engine)
engine.run([0] * 8, max_epochs=1)
- torch.testing.assert_allclose(net2.state_dict()["module.weight"].cpu(), torch.tensor([0.1]))
+ assert_allclose(net2.state_dict()["module.weight"].cpu(), torch.tensor([0.1]))
def test_partial_under_load(self):
net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()])
@@ -115,7 +116,7 @@ def test_partial_under_load(self):
engine = Engine(lambda e, b: None)
CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False).attach(engine)
engine.run([0] * 8, max_epochs=1)
- torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1]))
+ assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1]))
def test_partial_over_load(self):
net1 = torch.nn.Sequential(*[torch.nn.PReLU()])
@@ -137,7 +138,7 @@ def test_partial_over_load(self):
engine = Engine(lambda e, b: None)
CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False).attach(engine)
engine.run([0] * 8, max_epochs=1)
- torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1]))
+ assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1]))
def test_strict_shape(self):
net1 = torch.nn.Sequential(*[torch.nn.PReLU(num_parameters=5)])
@@ -168,7 +169,7 @@ def test_strict_shape(self):
strict_shape=False,
).attach(engine)
engine.run([0] * 8, max_epochs=1)
- torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2]))
+ assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2]))
# test whether `opt2` had been skipped when loading with `strict_shape=False`,
# it should have 2 items in `params`(0.weight and 1.weight) while the checkpoint has 1 item(0.weight)
self.assertEqual(len(opt1.state_dict()["param_groups"][0]["params"]), 1)
diff --git a/tests/test_handler_confusion_matrix.py b/tests/test_handler_confusion_matrix.py
index 5bddef26af7..ee6f3cd6815 100644
--- a/tests/test_handler_confusion_matrix.py
+++ b/tests/test_handler_confusion_matrix.py
@@ -17,6 +17,7 @@
from parameterized import parameterized
from monai.handlers import ConfusionMatrix
+from tests.utils import assert_allclose
TEST_CASE_1 = [{"include_background": True, "save_details": False, "metric_name": "f1"}, 0.75]
TEST_CASE_2 = [{"include_background": False, "save_details": False, "metric_name": "ppv"}, 1.0]
@@ -60,7 +61,7 @@ def test_compute(self, input_params, expected_avg):
metric.update([y_pred, y])
avg_metric = metric.compute()
- torch.testing.assert_allclose(avg_metric, expected_avg)
+ assert_allclose(avg_metric, expected_avg, atol=1e-4, rtol=1e-4, type_test=False)
@parameterized.expand([TEST_CASE_SEG_1])
def test_compute_seg(self, input_params, expected_avg):
diff --git a/tests/test_handler_confusion_matrix_dist.py b/tests/test_handler_confusion_matrix_dist.py
index 325a7990981..511e84d22ae 100644
--- a/tests/test_handler_confusion_matrix_dist.py
+++ b/tests/test_handler_confusion_matrix_dist.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import numpy as np
diff --git a/tests/test_handler_decollate_batch.py b/tests/test_handler_decollate_batch.py
index 1a43ae295bc..757708ea2b0 100644
--- a/tests/test_handler_decollate_batch.py
+++ b/tests/test_handler_decollate_batch.py
@@ -16,6 +16,7 @@
from monai.engines import SupervisedEvaluator
from monai.handlers import DecollateBatch, PostProcessing
from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd
+from tests.utils import assert_allclose
class TestHandlerDecollateBatch(unittest.TestCase):
@@ -53,7 +54,7 @@ def test_compute(self):
expected = torch.tensor([[[[1.0], [1.0]], [[0.0], [0.0]]]])
for o, e in zip(engine.state.output, expected):
- torch.testing.assert_allclose(o["pred"], e)
+ assert_allclose(o["pred"], e)
filename = o.get("filename_bak")
if filename is not None:
self.assertEqual(filename, "test2")
diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py
index 0350ba62fb6..e3bc3411b92 100644
--- a/tests/test_handler_garbage_collector.py
+++ b/tests/test_handler_garbage_collector.py
@@ -24,7 +24,6 @@
Events, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
-
TEST_CASE_0 = [[0, 1, 2], "epoch"]
TEST_CASE_1 = [[0, 1, 2], "iteration"]
diff --git a/tests/test_handler_hausdorff_distance.py b/tests/test_handler_hausdorff_distance.py
index 7e38f0ad56e..1fe9b2e4a30 100644
--- a/tests/test_handler_hausdorff_distance.py
+++ b/tests/test_handler_hausdorff_distance.py
@@ -17,6 +17,7 @@
from ignite.engine import Engine
from monai.handlers import HausdorffDistance
+from tests.utils import assert_allclose
def create_spherical_seg_3d(
@@ -102,7 +103,7 @@ def _val_func(engine, batch):
hd_metric.update([y_pred, y])
y_pred, y = TEST_SAMPLE_2
hd_metric.update([y_pred, y])
- torch.testing.assert_allclose(hd_metric.compute().float(), torch.tensor([10.0, 0.0]))
+ assert_allclose(hd_metric.compute().float(), torch.tensor([10.0, 0.0]))
if __name__ == "__main__":
diff --git a/tests/test_handler_logfile.py b/tests/test_handler_logfile.py
new file mode 100644
index 00000000000..b67bf63a5b7
--- /dev/null
+++ b/tests/test_handler_logfile.py
@@ -0,0 +1,83 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+import unittest
+
+import torch
+
+from monai.utils import optional_import
+from tests.utils import SkipIfNoModule
+
+try:
+ _, has_ignite = optional_import("ignite")
+ from ignite.engine import Engine
+
+ from monai.handlers import LogfileHandler
+except ImportError:
+ has_ignite = False
+
+
+class TestHandlerLogfile(unittest.TestCase):
+ def setUp(self):
+ if has_ignite:
+ # set up engine
+ def _train_func(engine, batch):
+ return torch.tensor(0.0)
+
+ self.engine = Engine(_train_func)
+
+ logger = self.engine.logger
+
+ # remove all other handlers to prevent output
+ while logger is not None:
+ del logger.handlers[:]
+ logger = logger.parent
+
+ @SkipIfNoModule("ignite")
+ def test_logfile(self):
+ with tempfile.TemporaryDirectory() as tempdir:
+ handler = LogfileHandler(output_dir=tempdir)
+ handler.attach(self.engine)
+
+ self.engine.run(range(3))
+
+ self.assertTrue(os.path.isfile(os.path.join(tempdir, "log.txt")))
+
+ @SkipIfNoModule("ignite")
+ def test_filename(self):
+ filename = "something_else.txt"
+
+ with tempfile.TemporaryDirectory() as tempdir:
+
+ handler = LogfileHandler(output_dir=tempdir, filename=filename)
+ handler.attach(self.engine)
+
+ self.engine.run(range(3))
+
+ self.assertTrue(os.path.isfile(os.path.join(tempdir, filename)))
+
+ @SkipIfNoModule("ignite")
+ def test_createdir(self):
+ with tempfile.TemporaryDirectory() as tempdir:
+ output_dir = os.path.join(tempdir, "new_dir")
+
+ handler = LogfileHandler(output_dir=output_dir)
+ handler.attach(self.engine)
+
+ self.engine.run(range(3))
+
+ self.assertTrue(os.path.isfile(os.path.join(output_dir, "log.txt")))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_handler_mean_dice.py b/tests/test_handler_mean_dice.py
index f309c7e693d..88eb4fbdcd9 100644
--- a/tests/test_handler_mean_dice.py
+++ b/tests/test_handler_mean_dice.py
@@ -16,6 +16,7 @@
from parameterized import parameterized
from monai.handlers import MeanDice, from_engine
+from tests.utils import assert_allclose
TEST_CASE_1 = [{"include_background": True, "output_transform": from_engine(["pred", "label"])}, 0.75, (4, 2)]
TEST_CASE_2 = [{"include_background": False, "output_transform": from_engine(["pred", "label"])}, 0.66666, (4, 1)]
@@ -32,6 +33,7 @@ class TestHandlerMeanDice(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_compute(self, input_params, expected_avg, details_shape):
dice_metric = MeanDice(**input_params)
+
# set up engine
def _val_func(engine, batch):
@@ -51,7 +53,7 @@ def _val_func(engine, batch):
engine.fire_event(Events.ITERATION_COMPLETED)
engine.fire_event(Events.EPOCH_COMPLETED)
- torch.testing.assert_allclose(engine.state.metrics["mean_dice"], expected_avg)
+ assert_allclose(engine.state.metrics["mean_dice"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)
self.assertTupleEqual(tuple(engine.state.metric_details["mean_dice"].shape), details_shape)
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
diff --git a/tests/test_handler_mean_iou.py b/tests/test_handler_mean_iou.py
new file mode 100644
index 00000000000..fdd4a5d04db
--- /dev/null
+++ b/tests/test_handler_mean_iou.py
@@ -0,0 +1,74 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from ignite.engine import Engine, Events
+from parameterized import parameterized
+
+from monai.handlers import MeanIoUHandler, from_engine
+from tests.utils import assert_allclose
+
+TEST_CASE_1 = [{"include_background": True, "output_transform": from_engine(["pred", "label"])}, 0.75, (4, 2)]
+TEST_CASE_2 = [{"include_background": False, "output_transform": from_engine(["pred", "label"])}, 2 / 3, (4, 1)]
+TEST_CASE_3 = [
+ {"reduction": "mean_channel", "output_transform": from_engine(["pred", "label"])},
+ torch.Tensor([1.0, 0.0, 1.0, 1.0]),
+ (4, 2),
+]
+
+
+class TestHandlerMeanIoU(unittest.TestCase):
+ # TODO test multi node averaged iou
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+ def test_compute(self, input_params, expected_avg, details_shape):
+ iou_metric = MeanIoUHandler(**input_params)
+
+ # set up engine
+
+ def _val_func(engine, batch):
+ pass
+
+ engine = Engine(_val_func)
+ iou_metric.attach(engine=engine, name="mean_iou")
+ # test input a list of channel-first tensor
+ y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]
+ y = torch.Tensor([[[0], [1]], [[0], [1]]])
+ engine.state.output = {"pred": y_pred, "label": y}
+ engine.fire_event(Events.ITERATION_COMPLETED)
+
+ y_pred = [torch.Tensor([[0], [1]]), torch.Tensor([[1], [0]])]
+ y = torch.Tensor([[[0], [1]], [[1], [0]]])
+ engine.state.output = {"pred": y_pred, "label": y}
+ engine.fire_event(Events.ITERATION_COMPLETED)
+
+ engine.fire_event(Events.EPOCH_COMPLETED)
+ assert_allclose(engine.state.metrics["mean_iou"], expected_avg)
+ self.assertTupleEqual(tuple(engine.state.metric_details["mean_iou"].shape), details_shape)
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+ def test_shape_mismatch(self, input_params, _expected_avg, _details_shape):
+ iou_metric = MeanIoUHandler(**input_params)
+ with self.assertRaises((AssertionError, ValueError)):
+ y_pred = torch.Tensor([[0, 1], [1, 0]])
+ y = torch.ones((3, 30))
+ iou_metric.update([y_pred, y])
+
+ with self.assertRaises((AssertionError, ValueError)):
+ y_pred = torch.Tensor([[0, 1], [1, 0]])
+ y = torch.ones((8, 30))
+ iou_metric.update([y_pred, y])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py
index a92fdf93d37..426d99c2237 100644
--- a/tests/test_handler_metrics_saver_dist.py
+++ b/tests/test_handler_metrics_saver_dist.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import csv
import os
import tempfile
diff --git a/tests/test_handler_nvtx.py b/tests/test_handler_nvtx.py
index eeca15ea6f3..29a1b8e4fb1 100644
--- a/tests/test_handler_nvtx.py
+++ b/tests/test_handler_nvtx.py
@@ -19,6 +19,7 @@
from monai.handlers import StatsHandler, from_engine
from monai.handlers.nvtx_handlers import MarkHandler, RangeHandler, RangePopHandler, RangePushHandler
from monai.utils import optional_import
+from tests.utils import assert_allclose
_, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")
@@ -83,7 +84,7 @@ def test_compute(self, data, expected):
# Get the output from the engine
output = engine.state.output[0]
- torch.testing.assert_allclose(output["pred"], expected)
+ assert_allclose(output["pred"], expected)
if __name__ == "__main__":
diff --git a/tests/test_handler_panoptic_quality.py b/tests/test_handler_panoptic_quality.py
new file mode 100644
index 00000000000..a852ee929ab
--- /dev/null
+++ b/tests/test_handler_panoptic_quality.py
@@ -0,0 +1,86 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from ignite.engine import Engine, Events
+from parameterized import parameterized
+
+from monai.handlers import PanopticQuality, from_engine
+from tests.utils import SkipIfNoModule, assert_allclose
+
+sample_1_pred = torch.as_tensor(
+ [[[0, 1, 1, 1], [0, 0, 5, 5], [2, 0, 3, 3], [2, 2, 2, 0]], [[0, 1, 1, 1], [0, 0, 0, 0], [2, 0, 3, 3], [4, 2, 2, 0]]]
+)
+
+sample_1_gt = torch.as_tensor(
+ [[[0, 6, 6, 6], [1, 0, 5, 5], [1, 0, 3, 3], [1, 3, 2, 0]], [[0, 1, 1, 1], [0, 0, 1, 1], [2, 0, 3, 3], [4, 4, 4, 3]]]
+)
+
+sample_2_pred = torch.as_tensor(
+ [[[3, 1, 1, 1], [3, 1, 1, 4], [3, 1, 4, 4], [3, 2, 2, 4]], [[0, 1, 1, 1], [2, 2, 2, 2], [2, 0, 0, 3], [4, 2, 2, 3]]]
+)
+
+sample_2_gt = torch.as_tensor(
+ [[[0, 6, 6, 6], [1, 0, 5, 5], [1, 0, 3, 3], [1, 3, 2, 0]], [[0, 1, 1, 1], [2, 1, 1, 3], [2, 0, 0, 3], [4, 2, 2, 3]]]
+)
+
+TEST_CASE_1 = [{"num_classes": 4, "output_transform": from_engine(["pred", "label"])}, [0.6667, 0.1538, 0.6667, 0.5714]]
+TEST_CASE_2 = [
+ {
+ "num_classes": 5,
+ "output_transform": from_engine(["pred", "label"]),
+ "metric_name": "rq",
+ "match_iou_threshold": 0.3,
+ },
+ [0.6667, 0.7692, 0.8889, 0.5714, 0.0000],
+]
+TEST_CASE_3 = [
+ {
+ "num_classes": 5,
+ "reduction": "mean",
+ "output_transform": from_engine(["pred", "label"]),
+ "metric_name": "SQ",
+ "match_iou_threshold": 0.2,
+ },
+ 0.8235,
+]
+
+
+@SkipIfNoModule("scipy.optimize")
+class TestHandlerPanopticQuality(unittest.TestCase):
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+ def test_compute(self, input_params, expected_avg):
+ metric = PanopticQuality(**input_params)
+ # set up engine
+
+ def _val_func(engine, batch):
+ pass
+
+ engine = Engine(_val_func)
+ metric.attach(engine=engine, name="panoptic_quality")
+ # test input a list of channel-first tensor
+ y_pred = [sample_1_pred, sample_2_pred]
+ y = [sample_1_gt, sample_2_gt]
+ engine.state.output = {"pred": y_pred, "label": y}
+ engine.fire_event(Events.ITERATION_COMPLETED)
+ y_pred = [sample_1_pred, sample_1_pred]
+ y = [sample_1_gt, sample_1_gt]
+ engine.state.output = {"pred": y_pred, "label": y}
+ engine.fire_event(Events.ITERATION_COMPLETED)
+
+ engine.fire_event(Events.EPOCH_COMPLETED)
+ assert_allclose(engine.state.metrics["panoptic_quality"], expected_avg, atol=1e-4, rtol=1e-4, type_test=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py
index 72742f19566..22bf046b83c 100644
--- a/tests/test_handler_parameter_scheduler.py
+++ b/tests/test_handler_parameter_scheduler.py
@@ -11,11 +11,11 @@
import unittest
-import torch
from ignite.engine import Engine, Events
from torch.nn import Module
from monai.handlers.parameter_scheduler import ParamSchedulerHandler
+from tests.utils import assert_allclose
class ToyNet(Module):
@@ -46,7 +46,7 @@ def test_linear_scheduler(self):
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=2)
- torch.testing.assert_allclose(net.get_value(), 0)
+ assert_allclose(net.get_value(), 0)
# Testing linear increase
net = ToyNet(value=-1)
@@ -59,7 +59,7 @@ def test_linear_scheduler(self):
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=3)
- torch.testing.assert_allclose(net.get_value(), 3.333333, atol=0.001, rtol=0.0)
+ assert_allclose(net.get_value(), 3.333333, atol=0.001, rtol=0.0)
# Testing max_value
net = ToyNet(value=-1)
@@ -72,7 +72,7 @@ def test_linear_scheduler(self):
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=10)
- torch.testing.assert_allclose(net.get_value(), 10)
+ assert_allclose(net.get_value(), 10)
def test_exponential_scheduler(self):
net = ToyNet(value=-1)
@@ -85,7 +85,7 @@ def test_exponential_scheduler(self):
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=2)
- torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
+ assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
def test_step_scheduler(self):
net = ToyNet(value=-1)
@@ -98,7 +98,7 @@ def test_step_scheduler(self):
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=10)
- torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
+ assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
def test_multistep_scheduler(self):
net = ToyNet(value=-1)
@@ -111,7 +111,7 @@ def test_multistep_scheduler(self):
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=10)
- torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
+ assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
def test_custom_scheduler(self):
def custom_logic(initial_value, gamma, current_step):
@@ -127,7 +127,7 @@ def custom_logic(initial_value, gamma, current_step):
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=2)
- torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
+ assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
if __name__ == "__main__":
diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py
index 89087e17655..10c7ac4a8b4 100644
--- a/tests/test_handler_post_processing.py
+++ b/tests/test_handler_post_processing.py
@@ -17,6 +17,7 @@
from monai.engines import SupervisedEvaluator
from monai.handlers import PostProcessing
from monai.transforms import Activationsd, AsDiscreted, Compose, CopyItemsd
+from tests.utils import assert_allclose
# test lambda function as `transform`
TEST_CASE_1 = [{"transform": lambda x: dict(pred=x["pred"] + 1.0)}, False, torch.tensor([[[[1.9975], [1.9997]]]])]
@@ -58,13 +59,13 @@ def test_compute(self, input_params, decollate, expected):
if isinstance(engine.state.output, list):
# test decollated list items
for o, e in zip(engine.state.output, expected):
- torch.testing.assert_allclose(o["pred"], e)
+ assert_allclose(o["pred"], e, atol=1e-4, rtol=1e-4, type_test=False)
filename = o.get("filename_bak")
if filename is not None:
self.assertEqual(filename, "test2")
else:
# test batch data
- torch.testing.assert_allclose(engine.state.output["pred"], expected)
+ assert_allclose(engine.state.output["pred"], expected, atol=1e-4, rtol=1e-4, type_test=False)
if __name__ == "__main__":
diff --git a/tests/test_handler_prob_map_producer.py b/tests/test_handler_prob_map_producer.py
index dbb4f57b620..a968e7dea08 100644
--- a/tests/test_handler_prob_map_producer.py
+++ b/tests/test_handler_prob_map_producer.py
@@ -16,14 +16,13 @@
import torch
from ignite.engine import Engine
from parameterized import parameterized
-from torch.utils.data import DataLoader
-from monai.data.dataset import Dataset
+from monai.data import DataLoader, Dataset, MetaTensor
from monai.engines import Evaluator
from monai.handlers import ProbMapProducer, ValidationHandler
from monai.utils.enums import ProbMapKeys
-TEST_CASE_0 = ["temp_image_inference_output_1", 2]
+TEST_CASE_0 = ["temp_image_inference_output_1", 1]
TEST_CASE_1 = ["temp_image_inference_output_2", 9]
TEST_CASE_2 = ["temp_image_inference_output_3", 100]
@@ -35,32 +34,32 @@ def __init__(self, name, size):
{
"image": name,
ProbMapKeys.COUNT.value: size,
- ProbMapKeys.SIZE.value: np.array([size, size]),
+ ProbMapKeys.SIZE.value: np.array([size + 1, size + 1]),
ProbMapKeys.LOCATION.value: np.array([i, i + 1]),
}
- for i in range(size - 1)
+ for i in range(size)
]
)
self.image_data = [
{
ProbMapKeys.NAME.value: name,
ProbMapKeys.COUNT.value: size,
- ProbMapKeys.SIZE.value: np.array([size, size]),
+ ProbMapKeys.SIZE.value: np.array([size + 1, size + 1]),
}
]
def __getitem__(self, index):
- return {
- "image": np.zeros((3, 2, 2)),
+
+ image = np.ones((3, 2, 2)) * index
+ metadata = {
ProbMapKeys.COUNT.value: self.data[index][ProbMapKeys.COUNT.value],
- "metadata": {
- ProbMapKeys.NAME.value: self.data[index]["image"],
- ProbMapKeys.SIZE.value: self.data[index][ProbMapKeys.SIZE.value],
- ProbMapKeys.LOCATION.value: self.data[index][ProbMapKeys.LOCATION.value],
- },
- "pred": index + 1,
+ ProbMapKeys.NAME.value: self.data[index]["image"],
+ ProbMapKeys.SIZE.value: self.data[index][ProbMapKeys.SIZE.value],
+ ProbMapKeys.LOCATION.value: self.data[index][ProbMapKeys.LOCATION.value],
}
+ return {"image": MetaTensor(x=image, meta=metadata), "pred": index + 1}
+
class TestEvaluator(Evaluator):
def _iteration(self, engine, batchdata):
@@ -72,10 +71,11 @@ class TestHandlerProbMapGenerator(unittest.TestCase):
def test_prob_map_generator(self, name, size):
# set up dataset
dataset = TestDataset(name, size)
- data_loader = DataLoader(dataset, batch_size=1)
+ batch_size = 2
+ data_loader = DataLoader(dataset, batch_size=batch_size)
# set up engine
- def inference(enging, batch):
+ def inference(engine, batch):
pass
engine = Engine(inference)
@@ -84,7 +84,9 @@ def inference(enging, batch):
output_dir = os.path.join(os.path.dirname(__file__), "testing_data")
prob_map_gen = ProbMapProducer(output_dir=output_dir)
- evaluator = TestEvaluator(torch.device("cpu:0"), data_loader, size, val_handlers=[prob_map_gen])
+ evaluator = TestEvaluator(
+ torch.device("cpu:0"), data_loader, np.ceil(size / batch_size), val_handlers=[prob_map_gen]
+ )
# set up validation handler
validation = ValidationHandler(interval=1, validator=None)
@@ -94,8 +96,8 @@ def inference(enging, batch):
engine.run(data_loader)
prob_map = np.load(os.path.join(output_dir, name + ".npy"))
- self.assertListEqual(np.vstack(prob_map.nonzero()).T.tolist(), [[i, i + 1] for i in range(size - 1)])
- self.assertListEqual(prob_map[prob_map.nonzero()].tolist(), [i + 1 for i in range(size - 1)])
+ self.assertListEqual(np.vstack(prob_map.nonzero()).T.tolist(), [[i, i + 1] for i in range(size)])
+ self.assertListEqual(prob_map[prob_map.nonzero()].tolist(), [i + 1 for i in range(size)])
if __name__ == "__main__":
diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py
index 5113911d7c1..994cbe139b7 100644
--- a/tests/test_handler_rocauc_dist.py
+++ b/tests/test_handler_rocauc_dist.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import numpy as np
diff --git a/tests/test_handler_surface_distance.py b/tests/test_handler_surface_distance.py
index c9901819981..6d245693ac5 100644
--- a/tests/test_handler_surface_distance.py
+++ b/tests/test_handler_surface_distance.py
@@ -17,6 +17,7 @@
from ignite.engine import Engine
from monai.handlers import SurfaceDistance
+from tests.utils import assert_allclose
def create_spherical_seg_3d(
@@ -102,7 +103,9 @@ def _val_func(engine, batch):
sur_metric.update([y_pred, y])
y_pred, y = TEST_SAMPLE_2
sur_metric.update([y_pred, y])
- torch.testing.assert_allclose(sur_metric.compute().float(), torch.tensor([4.1713, 0.0000]))
+ assert_allclose(
+ sur_metric.compute().float(), torch.tensor([4.1713, 0.0000]), atol=1e-4, rtol=1e-4, type_test=False
+ )
if __name__ == "__main__":
diff --git a/tests/test_hovernet.py b/tests/test_hovernet.py
index 568aeb04dcb..389bb8c10ff 100644
--- a/tests/test_hovernet.py
+++ b/tests/test_hovernet.py
@@ -15,84 +15,187 @@
from parameterized import parameterized
from monai.networks import eval_mode
-from monai.networks.nets import HoverNet
+from monai.networks.nets import HoVerNet
from tests.utils import test_script_save
device = "cuda" if torch.cuda.is_available() else "cpu"
-TEST_CASE_0 = [ # fast mode, batch 16
- {"out_classes": 5, "mode": HoverNet.Mode.FAST},
+TEST_CASE_0 = [ # fast mode
+ {"out_classes": 5, "mode": HoVerNet.Mode.FAST},
(1, 3, 256, 256),
- {
- "nucleus_prediction": (1, 2, 164, 164),
- "type_prediction": (1, 5, 164, 164),
- "horizonal_vertical": (1, 2, 164, 164),
- },
+ {HoVerNet.Branch.NP: (1, 2, 164, 164), HoVerNet.Branch.NC: (1, 5, 164, 164), HoVerNet.Branch.HV: (1, 2, 164, 164)},
]
-TEST_CASE_1 = [ # single channel 2D, batch 16
- {"mode": HoverNet.Mode.FAST},
- (1, 3, 256, 256),
- {"nucleus_prediction": (1, 2, 164, 164), "horizonal_vertical": (1, 2, 164, 164)},
-]
-
-TEST_CASE_2 = [ # single channel 3D, batch 16
- {"mode": HoverNet.Mode.ORIGINAL},
+TEST_CASE_1 = [ # original mode
+ {"out_classes": 6, "mode": HoVerNet.Mode.ORIGINAL},
(1, 3, 270, 270),
- {"nucleus_prediction": (1, 2, 80, 80), "horizonal_vertical": (1, 2, 80, 80)},
+ {HoVerNet.Branch.NP: (1, 2, 80, 80), HoVerNet.Branch.NC: (1, 6, 80, 80), HoVerNet.Branch.HV: (1, 2, 80, 80)},
]
-TEST_CASE_3 = [ # 4-channel 3D, batch 16
- {"out_classes": 6, "mode": HoverNet.Mode.ORIGINAL},
- (1, 3, 270, 270),
- {"nucleus_prediction": (1, 2, 80, 80), "type_prediction": (1, 6, 80, 80), "horizonal_vertical": (1, 2, 80, 80)},
+TEST_CASE_2 = [ # dropout
+ {"mode": HoVerNet.Mode.FAST, "dropout_prob": 0.5, "out_classes": 3},
+ (1, 3, 256, 256),
+ {HoVerNet.Branch.NP: (1, 2, 164, 164), HoVerNet.Branch.NC: (1, 3, 164, 164), HoVerNet.Branch.HV: (1, 2, 164, 164)},
]
-TEST_CASE_4 = [ # 4-channel 3D, batch 16, batch normalization
- {"mode": HoverNet.Mode.FAST, "dropout_prob": 0.5},
+TEST_CASE_3 = [ # np_out_channels
+ {"mode": HoVerNet.Mode.FAST, "np_out_channels": 3, "out_classes": 2},
(1, 3, 256, 256),
- {"nucleus_prediction": (1, 2, 164, 164), "horizonal_vertical": (1, 2, 164, 164)},
+ {HoVerNet.Branch.NP: (1, 3, 164, 164), HoVerNet.Branch.NC: (1, 2, 164, 164), HoVerNet.Branch.HV: (1, 2, 164, 164)},
]
-CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]
+CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]
ILL_CASES = [
[{"out_classes": 6, "mode": 3}],
- [{"out_classes": 1000, "mode": HoverNet.Mode.ORIGINAL}],
- [{"out_classes": 1, "mode": HoverNet.Mode.ORIGINAL}],
- [{"out_classes": 6, "mode": HoverNet.Mode.ORIGINAL, "dropout_prob": 100}],
+ [{"out_classes": 6, "mode": "Wrong"}],
+ [{"out_classes": 1000, "mode": HoVerNet.Mode.ORIGINAL}],
+ [{"out_classes": 1, "mode": HoVerNet.Mode.ORIGINAL}],
+ [{"out_classes": 6, "mode": HoVerNet.Mode.ORIGINAL, "dropout_prob": 100}],
]
+def check_branch(branch, mode):
+ if mode == HoVerNet.Mode.ORIGINAL:
+ ksize = 5
+ else:
+ ksize = 3
+
+ if branch.decoderblock1.conva.kernel_size != (ksize, ksize):
+ return True
+ if branch.decoderblock1.convf.kernel_size != (1, 1):
+ return True
+ for block in branch.decoderblock1:
+ if isinstance(block, HoVerNet._DenseLayerDecoder):
+ if block.layers.conv1.kernel_size != (1, 1) or block.layers.conv2.kernel_size != (ksize, ksize):
+ return True
+
+ if branch.decoderblock2.conva.kernel_size != (ksize, ksize):
+ return True
+ if branch.decoderblock2.convf.kernel_size != (1, 1):
+ return True
+
+ for block in branch.decoderblock2:
+ if isinstance(block, HoVerNet._DenseLayerDecoder):
+ if block.layers.conv1.kernel_size != (1, 1) or block.layers.conv2.kernel_size != (ksize, ksize):
+ return True
+
+ return False
+
+
+def check_output(out_block, mode):
+ if mode == HoVerNet.Mode.ORIGINAL:
+ ksize = 5
+ else:
+ ksize = 3
+
+ if out_block.decoderblock3.conva.kernel_size != (ksize, ksize) or out_block.decoderblock3.conva.stride != (1, 1):
+ return True
+ if out_block.decoderblock4.conv.kernel_size != (1, 1) or out_block.decoderblock4.conv.stride != (1, 1):
+ return True
+
+
+def check_kernels(net, mode):
+ # Check the Encoder blocks
+ for layer_num, res_block in enumerate(net.res_blocks):
+ for inner_num, layer in enumerate(res_block.layers):
+ if layer_num > 0 and inner_num == 0:
+ sz = 2
+ else:
+ sz = 1
+
+ if (
+ layer.layers.conv1.kernel_size != (1, 1)
+ or layer.layers.conv2.kernel_size != (3, 3)
+ or layer.layers.conv3.kernel_size != (1, 1)
+ ):
+ return True
+
+ if (
+ layer.layers.conv1.stride != (1, 1)
+ or layer.layers.conv2.stride != (sz, sz)
+ or layer.layers.conv3.stride != (1, 1)
+ ):
+ return True
+
+ sz2 = 1
+ if layer_num > 0:
+ sz2 = 2
+ if res_block.shortcut.kernel_size != (1, 1) or res_block.shortcut.stride != (sz2, sz2):
+ return True
+
+ if net.bottleneck.conv_bottleneck.kernel_size != (1, 1) or net.bottleneck.conv_bottleneck.stride != (1, 1):
+ return True
+
+ # Check HV Branch
+ if check_branch(net.horizontal_vertical.decoder_blocks, mode):
+ return True
+ if check_output(net.horizontal_vertical.output_features, mode):
+ return True
+
+ # Check NP Branch
+ if check_branch(net.nucleus_prediction.decoder_blocks, mode):
+ return True
+ if check_output(net.nucleus_prediction.output_features, mode):
+ return True
+
+ # Check NC Branch
+ if check_branch(net.type_prediction.decoder_blocks, mode):
+ return True
+ if check_output(net.type_prediction.output_features, mode):
+ return True
+
+
class TestHoverNet(unittest.TestCase):
@parameterized.expand(CASES)
def test_shape(self, input_param, input_shape, expected_shapes):
- net = HoverNet(**input_param).to(device)
+ input_param["decoder_padding"] = False
+ net = HoVerNet(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
for item in result:
self.assertEqual(result[item].shape, expected_shapes[item])
- def test_script(self):
- net = HoverNet(mode=HoverNet.Mode.FAST)
- test_data = torch.randn(1, 3, 256, 256)
- test_script_save(net, test_data)
+ @parameterized.expand(CASES)
+ def test_decoder_padding_shape(self, input_param, input_shape, expected_shapes):
+ if input_param["mode"] == HoVerNet.Mode.FAST:
+ input_param["decoder_padding"] = True
+ net = HoVerNet(**input_param).to(device)
+ with eval_mode(net):
+ result = net.forward(torch.randn(input_shape).to(device))
+ for item in result:
+ expected_shape = expected_shapes[item]
+ padding_expected_shape = list(expected_shape)
+ padding_expected_shape[2:] = input_shape[2:]
+ self.assertEqual(result[item].shape, tuple(padding_expected_shape))
+ else:
+ pass
- def test_script_without_running_stats(self):
- net = HoverNet(mode=HoverNet.Mode.FAST)
+ def test_script(self):
+ for padding_flag in [True, False]:
+ net = HoVerNet(mode=HoVerNet.Mode.FAST, decoder_padding=padding_flag)
test_data = torch.randn(1, 3, 256, 256)
test_script_save(net, test_data)
def test_ill_input_shape(self):
- net = HoverNet(mode=HoverNet.Mode.FAST)
+ net = HoVerNet(mode=HoVerNet.Mode.FAST)
with eval_mode(net):
with self.assertRaises(ValueError):
net.forward(torch.randn(1, 3, 270, 260))
+ def check_kernels_strides(self):
+ net = HoVerNet(mode=HoVerNet.Mode.FAST)
+ with eval_mode(net):
+ self.assertEqual(check_kernels(net, HoVerNet.Mode.FAST), False)
+
+ net = HoVerNet(mode=HoVerNet.Mode.ORIGINAL)
+ with eval_mode(net):
+ self.assertEqual(check_kernels(net, HoVerNet.Mode.ORIGINAL), False)
+
@parameterized.expand(ILL_CASES)
def test_ill_input_hyper_params(self, input_param):
with self.assertRaises(ValueError):
- _ = HoverNet(**input_param)
+ _ = HoVerNet(**input_param)
if __name__ == "__main__":
diff --git a/tests/test_hovernet_loss.py b/tests/test_hovernet_loss.py
new file mode 100644
index 00000000000..11c5fbdd65b
--- /dev/null
+++ b/tests/test_hovernet_loss.py
@@ -0,0 +1,186 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+from torch.nn import functional as F
+
+from monai.apps.pathology.losses import HoVerNetLoss
+from monai.transforms import GaussianSmooth, Rotate
+from monai.transforms.intensity.array import ComputeHoVerMaps
+from monai.utils.enums import HoVerNetBranch
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+s = 10e-8
+t = 1.0 - s
+H = 40
+W = 40
+N = 5
+B = 2
+
+
+class PrepareTestInputs:
+ def __init__(self, inputs):
+ self.inputs = {HoVerNetBranch.NP: inputs[1], HoVerNetBranch.HV: inputs[3]}
+ self.targets = {HoVerNetBranch.NP: inputs[0], HoVerNetBranch.HV: inputs[2]}
+
+ if len(inputs) > 4:
+ self.targets[HoVerNetBranch.NC] = inputs[4]
+ self.inputs[HoVerNetBranch.NC] = inputs[5]
+
+
+def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, width=5, rotation=0.0, smoothing=False):
+ t_g = torch.zeros((batch_size, height, width), dtype=torch.int64)
+ t_p = None
+ hv_g = torch.zeros((batch_size, 2, height, width))
+ hv_p = torch.zeros((batch_size, 2, height, width))
+
+ rad_min = 2
+ rad_max = min(max(height // 3, width // 3, rad_min), 5)
+
+ for b in range(batch_size):
+ random.seed(10 + b)
+ inst_map = torch.zeros((height, width), dtype=torch.int64)
+ for inst_id in range(1, num_objects + 1):
+ x = random.randint(rad_max, width - rad_max)
+ y = random.randint(rad_max, height - rad_max)
+ rad = random.randint(rad_min, rad_max)
+ spy, spx = np.ogrid[-x : height - x, -y : width - y]
+ circle = torch.tensor((spx * spx + spy * spy) <= rad * rad)
+
+ if num_classes > 1:
+ t_g[b, circle] = np.ceil(random.random() * num_classes)
+ else:
+ t_g[b, circle] = 1
+
+ inst_map[circle] = inst_id
+
+ hv_g[b] = ComputeHoVerMaps()(inst_map[None])
+ hv_g[b] = hv_g[b].squeeze(0)
+ if rotation > 0.0:
+ hv_p[b] = Rotate(angle=rotation, keep_size=True, mode="bilinear")(hv_g[b])
+
+ n_g = t_g > 0
+ if rotation == 0.0:
+ hv_p = hv_g * 0.99
+
+ # rotation of prediction needs to happen before one-hot encoding
+ if rotation > 0.0:
+ n_p = Rotate(angle=rotation, keep_size=True, mode="nearest")(n_g)
+ n_p = F.one_hot(n_p.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)
+ if num_classes > 1:
+ t_p = Rotate(angle=rotation, keep_size=True, mode="nearest")(t_g)
+ t_p = F.one_hot(t_p.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)
+ t_g = F.one_hot(t_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)
+ else:
+ t_g = None
+ else:
+ n_p = F.one_hot(n_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)
+ if num_classes > 1:
+ t_p = F.one_hot(t_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)
+ t_g = F.one_hot(t_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)
+ else:
+ t_g = None
+
+ n_g = F.one_hot(n_g.to(torch.int64)).to(torch.float32).permute(0, 3, 1, 2)
+
+ if smoothing:
+ n_p = GaussianSmooth()(n_p)
+ if num_classes > 1:
+ t_p = GaussianSmooth()(t_p)
+ hv_p = hv_p * 0.1
+ else:
+ n_p = torch.clamp(n_p, s, t)
+ if num_classes > 1:
+ t_p = torch.clamp(t_p, s, t)
+
+ # Apply log to emulate logits
+ if t_p is not None:
+ return n_g, n_p.log(), hv_g, hv_p, t_g, t_p.log()
+ else:
+ return n_g, n_p.log(), hv_g, hv_p
+
+
+inputs_test = [
+ PrepareTestInputs(test_shape_generator(height=H, width=W)),
+ PrepareTestInputs(test_shape_generator(num_classes=N, height=H, width=W)),
+ PrepareTestInputs(test_shape_generator(num_classes=N, batch_size=B, height=H, width=W)),
+ PrepareTestInputs(test_shape_generator(num_classes=N, batch_size=B, height=H, width=W, rotation=0.15)),
+ PrepareTestInputs(test_shape_generator(num_classes=N, batch_size=B, height=H, width=W, rotation=0.2)),
+ PrepareTestInputs(test_shape_generator(num_classes=N, batch_size=B, height=H, width=W, rotation=0.25)),
+]
+
+TEST_CASE_0 = [ # batch size of 1, no type prediction
+ {"prediction": inputs_test[0].inputs, "target": inputs_test[0].targets},
+ 0.003,
+]
+
+TEST_CASE_1 = [ # batch size of 1, 2 classes with type prediction
+ {"prediction": inputs_test[1].inputs, "target": inputs_test[1].targets},
+ 0.2762,
+]
+
+TEST_CASE_2 = [ # batch size of 2, 2 classes with type prediction
+ {"prediction": inputs_test[2].inputs, "target": inputs_test[2].targets},
+ 0.4852,
+]
+
+TEST_CASE_3 = [ # batch size of 2, 3 classes with minor rotation of nuclear prediction
+ {"prediction": inputs_test[3].inputs, "target": inputs_test[3].targets},
+ 3.6169,
+]
+
+TEST_CASE_4 = [ # batch size of 2, 3 classes with medium rotation of nuclear prediction
+ {"prediction": inputs_test[4].inputs, "target": inputs_test[4].targets},
+ 4.5079,
+]
+
+TEST_CASE_5 = [ # batch size of 2, 3 classes with medium rotation of nuclear prediction
+ {"prediction": inputs_test[5].inputs, "target": inputs_test[5].targets},
+ 5.4663,
+]
+
+CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]
+
+ILL_CASES = [
+ [
+ {
+ "prediction": {"np": inputs_test[0].inputs[HoVerNetBranch.NP]},
+ "target": {
+ "np": inputs_test[0].targets[HoVerNetBranch.NP],
+ HoVerNetBranch.HV: inputs_test[0].targets[HoVerNetBranch.HV],
+ },
+ }
+ ]
+]
+
+
+class TestHoverNetLoss(unittest.TestCase):
+ @parameterized.expand(CASES)
+ def test_shape(self, input_param, expected_loss):
+ loss = HoVerNetLoss()
+ result = loss(**input_param).to(device)
+ self.assertAlmostEqual(float(result), expected_loss, places=2)
+
+ @parameterized.expand(ILL_CASES)
+ def test_ill_input_hyper_params(self, input_param):
+ with self.assertRaises(ValueError):
+ loss = HoVerNetLoss()
+ _ = loss(**input_param).to(device)
+
+
+if __name__ == "__main__":
+ unittest.main(argv=["first-arg-is-ignored"], exit=False)
diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py
index 80b7304ea24..0bc23dcebaf 100644
--- a/tests/test_image_rw.py
+++ b/tests/test_image_rw.py
@@ -23,10 +23,13 @@
from monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer
from monai.data.meta_tensor import MetaTensor
from monai.transforms import LoadImage, SaveImage, moveaxis
-from monai.utils import MetaKeys, OptionalImportError
+from monai.utils import MetaKeys, OptionalImportError, optional_import
from tests.utils import TEST_NDARRAYS, assert_allclose
+_, has_itk = optional_import("itk", allow_namespace_pkg=True)
+
+@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadSaveNifti(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
@@ -62,6 +65,8 @@ def nifti_rw(self, test_data, reader, writer, dtype, resample=True):
_test_data = test_data[0]
if resample:
_test_data = moveaxis(_test_data, 0, 1)
+ assert_allclose(meta["qform_code"], 1, type_test=False)
+ assert_allclose(meta["sform_code"], 1, type_test=False)
assert_allclose(data, torch.as_tensor(_test_data))
@parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, "ITKWriter"]))
@@ -82,6 +87,7 @@ def test_4d(self, reader, writer):
self.nifti_rw(test_data, reader, writer, np.float16)
+@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadSavePNG(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
@@ -137,6 +143,7 @@ def test_1_new(self):
self.assertEqual(resolve_writer("new")[0](0), 1)
+@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadSaveNrrd(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
diff --git a/tests/test_integration_autorunner.py b/tests/test_integration_autorunner.py
new file mode 100644
index 00000000000..f62a8563b90
--- /dev/null
+++ b/tests/test_integration_autorunner.py
@@ -0,0 +1,145 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+import unittest
+from typing import Dict, List
+
+import nibabel as nib
+import numpy as np
+
+from monai.apps.auto3dseg import AutoRunner
+from monai.bundle.config_parser import ConfigParser
+from monai.data import create_test_image_3d
+from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick
+
+_, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter")
+_, has_nni = optional_import("nni")
+
+sim_datalist: Dict[str, List[Dict]] = {
+ "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}],
+ "training": [
+ {"fold": 0, "image": "tr_image_001.fake.nii.gz", "label": "tr_label_001.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_002.fake.nii.gz", "label": "tr_label_002.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_003.fake.nii.gz", "label": "tr_label_003.fake.nii.gz"},
+ {"fold": 0, "image": "tr_image_004.fake.nii.gz", "label": "tr_label_004.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_005.fake.nii.gz", "label": "tr_label_005.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_006.fake.nii.gz", "label": "tr_label_006.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_007.fake.nii.gz", "label": "tr_label_007.fake.nii.gz"},
+ {"fold": 1, "image": "tr_image_008.fake.nii.gz", "label": "tr_label_008.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_009.fake.nii.gz", "label": "tr_label_009.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_010.fake.nii.gz", "label": "tr_label_010.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_011.fake.nii.gz", "label": "tr_label_011.fake.nii.gz"},
+ {"fold": 2, "image": "tr_image_012.fake.nii.gz", "label": "tr_label_012.fake.nii.gz"},
+ ],
+}
+
+train_param = {
+ "CUDA_VISIBLE_DEVICES": [0],
+ "num_iterations": 8,
+ "num_iterations_per_validation": 4,
+ "num_images_per_batch": 2,
+ "num_epochs": 2,
+ "num_warmup_iterations": 4,
+}
+
+pred_param = {"files_slices": slice(0, 1), "mode": "mean", "sigmoid": True}
+
+
+@skip_if_quick
+@SkipIfBeforePyTorchVersion((1, 9, 1))
+@unittest.skipIf(not has_tb, "no tensorboard summary writer")
+class TestAutoRunner(unittest.TestCase):
+ def setUp(self) -> None:
+ self.test_dir = tempfile.TemporaryDirectory()
+ test_path = self.test_dir.name
+
+ sim_dataroot = os.path.join(test_path, "dataroot")
+ if not os.path.isdir(sim_dataroot):
+ os.makedirs(sim_dataroot)
+
+ # Generate a fake dataset
+ for d in sim_datalist["testing"] + sim_datalist["training"]:
+ im, seg = create_test_image_3d(64, 64, 64, rad_max=10, num_seg_classes=1)
+ nib_image = nib.Nifti1Image(im, affine=np.eye(4))
+ image_fpath = os.path.join(sim_dataroot, d["image"])
+ nib.save(nib_image, image_fpath)
+
+ if "label" in d:
+ nib_image = nib.Nifti1Image(seg, affine=np.eye(4))
+ label_fpath = os.path.join(sim_dataroot, d["label"])
+ nib.save(nib_image, label_fpath)
+
+ sim_json_datalist = os.path.join(sim_dataroot, "sim_input.json")
+ ConfigParser.export_config_file(sim_datalist, sim_json_datalist)
+
+ data_src_cfg = os.path.join(test_path, "data_src_cfg.yaml")
+ data_src = {
+ "name": "sim_data",
+ "task": "segmentation",
+ "modality": "MRI",
+ "datalist": sim_json_datalist,
+ "dataroot": sim_dataroot,
+ "multigpu": False,
+ "class_names": ["label_class"],
+ }
+
+ ConfigParser.export_config_file(data_src, data_src_cfg)
+ self.data_src_cfg = data_src_cfg
+ self.test_path = test_path
+
+ @skip_if_no_cuda
+ def test_autorunner(self) -> None:
+ work_dir = os.path.join(self.test_path, "work_dir")
+ runner = AutoRunner(work_dir=work_dir, input=self.data_src_cfg)
+ runner.set_training_params(train_param) # 2 epochs
+ runner.set_num_fold(1)
+ with skip_if_downloading_fails():
+ runner.run()
+
+ @skip_if_no_cuda
+ @unittest.skipIf(not has_nni, "nni required")
+ def test_autorunner_hpo(self) -> None:
+ work_dir = os.path.join(self.test_path, "work_dir")
+ runner = AutoRunner(work_dir=work_dir, input=self.data_src_cfg, hpo=True)
+ hpo_param = {
+ "num_iterations": 8,
+ "num_iterations_per_validation": 4,
+ "num_images_per_batch": 2,
+ "num_epochs": 2,
+ "num_warmup_iterations": 4,
+ # below are to shorten the time for dints
+ "training#num_iterations": 8,
+ "training#num_iterations_per_validation": 4,
+ "training#num_images_per_batch": 2,
+ "training#num_epochs": 2,
+ "training#num_warmup_iterations": 4,
+ "searching#num_iterations": 8,
+ "searching#num_iterations_per_validation": 4,
+ "searching#num_images_per_batch": 2,
+ "searching#num_epochs": 2,
+ "searching#num_warmup_iterations": 4,
+ }
+ search_space = {"learning_rate": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]}}
+ runner.set_num_fold(1)
+ runner.set_nni_search_space(search_space)
+ runner.set_hpo_params(params=hpo_param)
+ with skip_if_downloading_fails():
+ runner.run()
+
+ def tearDown(self) -> None:
+ self.test_dir.cleanup()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py
index 6f20c55fe2b..1f6570eeeb9 100644
--- a/tests/test_integration_bundle_run.py
+++ b/tests/test_integration_bundle_run.py
@@ -12,7 +12,6 @@
import json
import os
import shutil
-import subprocess
import sys
import tempfile
import unittest
@@ -23,6 +22,7 @@
from monai.bundle import ConfigParser
from monai.transforms import LoadImage
+from tests.utils import command_line_tests
TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json"), (128, 128, 128)]
@@ -56,7 +56,7 @@ def test_tiny(self):
f,
)
cmd = ["coverage", "run", "-m", "monai.bundle", "run", "training", "--config_file", config_file]
- subprocess.check_call(cmd)
+ command_line_tests(cmd)
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, config_file, expected_shape):
@@ -96,7 +96,7 @@ def test_shape(self, config_file, expected_shape):
la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
test_env = os.environ.copy()
print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES"))
- subprocess.check_call(la + ["--args_file", def_args_file], env=test_env)
+ command_line_tests(la + ["--args_file", def_args_file])
loader = LoadImage(image_only=True)
self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_seg.nii.gz")).shape, expected_shape)
@@ -104,7 +104,7 @@ def test_shape(self, config_file, expected_shape):
cmd = "-m fire monai.bundle.scripts run --runner_id evaluating"
cmd += f" --evaluator#amp False {override}"
la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file]
- subprocess.check_call(la, env=test_env)
+ command_line_tests(la)
self.assertTupleEqual(loader(os.path.join(tempdir, "image", "image_trans.nii.gz")).shape, expected_shape)
diff --git a/tests/test_integration_stn.py b/tests/test_integration_stn.py
index e655ff67556..5b9b22668af 100644
--- a/tests/test_integration_stn.py
+++ b/tests/test_integration_stn.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import numpy as np
diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py
index 0ef95d40057..342e70cc8e2 100644
--- a/tests/test_integration_workflows.py
+++ b/tests/test_integration_workflows.py
@@ -52,7 +52,7 @@
)
from monai.utils import optional_import, set_determinism
from tests.testing_data.integration_answers import test_integration_value
-from tests.utils import DistTestCase, TimedCall, pytorch_after, skip_if_quick
+from tests.utils import DistTestCase, TimedCall, assert_allclose, pytorch_after, skip_if_quick
SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter")
@@ -205,7 +205,12 @@ def _model_completed(self, engine):
)
trainer.run()
- return evaluator.state.best_metric
+ # test train and validation stats
+ train_stats = trainer.get_stats("output")
+ assert_allclose(train_stats["output"][0]["loss"], trainer.state.output[0]["loss"])
+ val_stats = evaluator.get_stats("metrics")
+
+ return val_stats["best_validation_metric"]
def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_workers=4):
diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py
index ff53851ce03..7dd05848bb5 100644
--- a/tests/test_integration_workflows_gan.py
+++ b/tests/test_integration_workflows_gan.py
@@ -22,11 +22,11 @@
import monai
from monai.data import create_test_image_2d
from monai.engines import GanTrainer
-from monai.engines.utils import GanKeys as Keys
from monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler
from monai.networks import normal_init
from monai.networks.nets import Discriminator, Generator
from monai.transforms import AsChannelFirstd, Compose, LoadImaged, RandFlipd, ScaleIntensityd
+from monai.utils import GanKeys as Keys
from monai.utils import set_determinism
from tests.utils import DistTestCase, TimedCall, skip_if_quick
diff --git a/tests/test_invert.py b/tests/test_invert.py
index b867a646fa5..4bd648c2643 100644
--- a/tests/test_invert.py
+++ b/tests/test_invert.py
@@ -34,7 +34,7 @@
Spacing,
)
from monai.utils import set_determinism
-from tests.utils import make_nifti_image
+from tests.utils import assert_allclose, make_nifti_image
class TestInvert(unittest.TestCase):
@@ -73,7 +73,7 @@ def test_invert(self):
i = inverter(item)
self.assertTupleEqual(orig.shape[1:], (100, 100, 100))
# check the nearest interpolation mode
- torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
+ assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
self.assertTupleEqual(i.shape[1:], (100, 101, 107))
# check labels match
reverted = i.detach().cpu().numpy().astype(np.int32)
diff --git a/tests/test_invertd.py b/tests/test_invertd.py
index fc4725d98b2..afa7958e9a9 100644
--- a/tests/test_invertd.py
+++ b/tests/test_invertd.py
@@ -35,7 +35,7 @@
Spacingd,
)
from monai.utils import set_determinism
-from tests.utils import make_nifti_image
+from tests.utils import assert_allclose, make_nifti_image
KEYS = ["image", "label"]
@@ -102,10 +102,10 @@ def test_invert(self):
self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
# check the nearest interpolation mode
i = item["image_inverted"]
- torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
+ assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
self.assertTupleEqual(i.shape[1:], (100, 101, 107))
i = item["label_inverted"]
- torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
+ assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
self.assertTupleEqual(i.shape[1:], (100, 101, 107))
# check the case that different items use different interpolation mode to invert transforms
diff --git a/tests/test_k_space_spike_noised.py b/tests/test_k_space_spike_noised.py
index 7a6403655c7..03c99d15339 100644
--- a/tests/test_k_space_spike_noised.py
+++ b/tests/test_k_space_spike_noised.py
@@ -43,7 +43,7 @@ def get_data(im_shape, im_type):
create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d
ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)
ims = [im_type(im[None]) for im in ims]
- return {k: v for k, v in zip(KEYS, ims)}
+ return dict(zip(KEYS, ims))
@parameterized.expand(TESTS)
def test_same_result(self, im_shape, im_type):
diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py
index 80dbc1c51d9..a0e309f2d70 100644
--- a/tests/test_keep_largest_connected_component.py
+++ b/tests/test_keep_largest_connected_component.py
@@ -78,6 +78,8 @@ def to_onehot(x):
]
grid_5 = [[[0, 0, 1, 0, 0], [0, 1, 1, 1, 1], [1, 1, 1, 0, 0], [1, 1, 0, 1, 0], [1, 1, 0, 0, 1]]]
+grid_6 = [[[0, 0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1], [1, 1, 0, 0, 1, 0, 1], [0, 0, 0, 1, 0, 0, 1]]]
+
TESTS = []
for p in TEST_NDARRAYS:
TESTS.append(
@@ -343,6 +345,37 @@ def to_onehot(x):
torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]),
]
)
+ # no connected regions
+ TESTS.append(["0 regions", {"num_components": 0}, p(grid_6), p(torch.zeros(1, 4, 7))])
+ # 1 connected region
+ TESTS.append(
+ [
+ "1 region",
+ {"num_components": 1},
+ p(grid_6),
+ p(
+ torch.tensor(
+ [[[0, 0, 1, 1, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 1, 0, 0, 0]]]
+ )
+ ),
+ ]
+ )
+ # 2 connected regions
+ TESTS.append(
+ [
+ "2 regions",
+ {"num_components": 2},
+ p(grid_6),
+ p(
+ torch.tensor(
+ [[[0, 0, 1, 1, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 1, 0, 1], [0, 0, 0, 1, 0, 0, 1]]]
+ )
+ ),
+ ]
+ )
+ # 3+ connected regions unchanged (as input has 3)
+ for num_connected in (3, 4):
+ TESTS.append([f"{num_connected} regions", {"num_components": num_connected}, p(grid_6), p(grid_6)])
class TestKeepLargestConnectedComponent(unittest.TestCase):
diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py
index b782f904410..42aa419b1d0 100644
--- a/tests/test_label_filter.py
+++ b/tests/test_label_filter.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import torch
@@ -51,7 +50,6 @@
VALID_TESTS.append(["filter_all", {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, p(grid_1), p(grid_1)])
-
ITEST_CASE_1 = ["invalid_image_data_type", {"applied_labels": 1}, [[[[1, 1, 1]]]], NotImplementedError]
INVALID_CASES = [ITEST_CASE_1]
diff --git a/tests/test_label_filterd.py b/tests/test_label_filterd.py
index d53dc21faff..eea18d02780 100644
--- a/tests/test_label_filterd.py
+++ b/tests/test_label_filterd.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import torch
@@ -51,7 +50,6 @@
VALID_TESTS.append(["filter_all", {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, p(grid_1), p(grid_1)])
-
ITEST_CASE_1 = ["invalid_image_data_type", {"applied_labels": 1}, [[[[1, 1, 1]]]], NotImplementedError]
INVALID_CASES = [ITEST_CASE_1]
diff --git a/tests/test_label_quality_score.py b/tests/test_label_quality_score.py
new file mode 100644
index 00000000000..db31624a953
--- /dev/null
+++ b/tests/test_label_quality_score.py
@@ -0,0 +1,130 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.metrics import LabelQualityScore, label_quality_score
+
+_device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# keep background, 1D Case
+TEST_CASE_1 = [ # y_pred (3, 1, 3), expected out (0.0)
+ {
+ "y_pred": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device),
+ "y": torch.tensor([[[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]], [[1.0, 1.0, 1.0]]], device=_device),
+ "include_background": True,
+ "scalar_reduction": "sum",
+ },
+ [0.0, 0.0, 0.0],
+]
+
+# keep background, 2D Case
+TEST_CASE_2 = [ # y_pred (1, 1, 2, 2), expected out (0.0)
+ {
+ "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
+ "y": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
+ "include_background": True,
+ "scalar_reduction": "sum",
+ },
+ [0.0],
+]
+
+# keep background, 3D Case
+TEST_CASE_3 = [ # y_pred (1, 1, 1, 2, 2), expected out (0.0)
+ {
+ "y_pred": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
+ "y": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
+ "include_background": True,
+ "scalar_reduction": "sum",
+ },
+ [0.0],
+]
+
+# keep background, 2D Case
+TEST_CASE_4 = [ # y_pred (1, 1, 2, 2), expected out (0.0)
+ {
+ "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
+ "y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]], device=_device),
+ "include_background": True,
+ "scalar_reduction": "sum",
+ },
+ [4.0],
+]
+
+TEST_CASE_5 = [ # y_pred (1, 1, 2, 2), expected out (0.0)
+ {
+ "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]]], device=_device),
+ "y": torch.tensor([[[[0.0, 0.0], [0.0, 0.0]]]], device=_device),
+ "include_background": True,
+ "scalar_reduction": "mean",
+ },
+ [1.0],
+]
+
+# Spatial Map Test Case for 3D Case
+TEST_CASE_6 = [ # y_pred (1, 1, 2, 2, 2), expected out all (0.0) map of 2x2x2
+ {
+ "y_pred": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
+ "y": torch.tensor([[[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]], device=_device),
+ "include_background": True,
+ "scalar_reduction": "none",
+ },
+ [[[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]]],
+]
+
+# Spatial Map Test Case for 2D Case
+TEST_CASE_7 = [ # y_pred (1, 1, 2, 2)
+ {
+ "y_pred": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=_device),
+ "y": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]], device=_device),
+ "include_background": True,
+ "scalar_reduction": "none",
+ },
+ [[[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]],
+]
+
+
+class TestLabelQualityScore(unittest.TestCase):
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
+ def test_value(self, input_data, expected_value):
+ result = label_quality_score(**input_data)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+ @parameterized.expand([TEST_CASE_6, TEST_CASE_7])
+ def test_spatial_case(self, input_data, expected_value):
+ result = label_quality_score(**input_data)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
+ def test_value_class(self, input_data, expected_value):
+ vals = {}
+ vals["y_pred"] = input_data.pop("y_pred")
+ vals["y"] = input_data.pop("y")
+ comp_var = LabelQualityScore(**input_data)
+ result = comp_var(**vals)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+ @parameterized.expand([TEST_CASE_6, TEST_CASE_7])
+ def test_spatial_case_class(self, input_data, expected_value):
+ vals = {}
+ vals["y_pred"] = input_data.pop("y_pred")
+ vals["y"] = input_data.pop("y")
+ comp_var = LabelQualityScore(**input_data)
+ result = comp_var(**vals)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_lesion_froc.py b/tests/test_lesion_froc.py
index b135b3eaeb5..8c9f751b1e1 100644
--- a/tests/test_lesion_froc.py
+++ b/tests/test_lesion_froc.py
@@ -114,7 +114,6 @@ def prepare_test_data():
np.nan,
]
-
TEST_CASE_1 = [
{
"data": [
@@ -163,7 +162,6 @@ def prepare_test_data():
1.0,
]
-
TEST_CASE_4 = [
{
"data": [
@@ -196,7 +194,6 @@ def prepare_test_data():
0.5,
]
-
TEST_CASE_6 = [
{
"data": [
diff --git a/tests/test_lltm.py b/tests/test_lltm.py
index 7633c2fe346..877d2117672 100644
--- a/tests/test_lltm.py
+++ b/tests/test_lltm.py
@@ -15,7 +15,7 @@
from parameterized import parameterized
from monai.networks.layers import LLTM
-from tests.utils import SkipIfNoModule, is_tf32_env
+from tests.utils import SkipIfNoModule, assert_allclose, is_tf32_env
_rtol = 0.001 if is_tf32_env() else 0.0001
@@ -37,8 +37,8 @@ def test_value(self, input_param, expected_h, expected_c):
new_h, new_c = LLTM(**input_param)(x, (h, c))
(new_h.sum() + new_c.sum()).backward()
- torch.testing.assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04)
- torch.testing.assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04)
+ assert_allclose(new_h, expected_h, rtol=0.0001, atol=1e-04)
+ assert_allclose(new_c, expected_c, rtol=0.0001, atol=1e-04)
@parameterized.expand([TEST_CASE_1])
@SkipIfNoModule("monai._C")
@@ -52,8 +52,8 @@ def test_value_cuda(self, input_param, expected_h, expected_c):
new_h, new_c = lltm(x, (h, c))
(new_h.sum() + new_c.sum()).backward()
- torch.testing.assert_allclose(new_h, expected_h.to(device), rtol=_rtol, atol=0.001)
- torch.testing.assert_allclose(new_c, expected_c.to(device), rtol=_rtol, atol=0.001)
+ assert_allclose(new_h, expected_h.to(device), rtol=_rtol, atol=0.001)
+ assert_allclose(new_c, expected_c.to(device), rtol=_rtol, atol=0.001)
if __name__ == "__main__":
diff --git a/tests/test_load_image.py b/tests/test_load_image.py
index cc227021a2e..1db39a310b0 100644
--- a/tests/test_load_image.py
+++ b/tests/test_load_image.py
@@ -15,19 +15,23 @@
import unittest
from pathlib import Path
-import itk
import nibabel as nib
import numpy as np
import torch
from parameterized import parameterized
from PIL import Image
-from monai.data import ITKReader, NibabelReader, PydicomReader
+from monai.data import NibabelReader, PydicomReader
from monai.data.meta_obj import set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.transforms import LoadImage
+from monai.utils import optional_import
from tests.utils import assert_allclose
+itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
+ITKReader, _ = optional_import("monai.data", name="ITKReader", as_type="decorator")
+itk_uc, _ = optional_import("itk", name="UC", allow_namespace_pkg=True)
+
class _MiniReader:
"""a test case customised reader"""
@@ -67,35 +71,39 @@ def get_data(self, _obj):
TEST_CASE_5 = [{"reader": NibabelReader(mmap=False)}, ["test_image.nii.gz"], (128, 128, 128)]
-TEST_CASE_6 = [{"reader": ITKReader()}, ["test_image.nii.gz"], (128, 128, 128)]
+TEST_CASE_6 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]
-TEST_CASE_7 = [{"reader": ITKReader()}, ["test_image.nii.gz"], (128, 128, 128)]
+TEST_CASE_7 = [{"reader": ITKReader() if has_itk else "itkreader"}, ["test_image.nii.gz"], (128, 128, 128)]
TEST_CASE_8 = [
- {"reader": ITKReader()},
+ {"reader": ITKReader() if has_itk else "itkreader"},
["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
(3, 128, 128, 128),
]
TEST_CASE_8_1 = [
- {"reader": ITKReader(channel_dim=0)},
+ {"reader": ITKReader(channel_dim=0) if has_itk else "itkreader"},
["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
(384, 128, 128),
]
-
TEST_CASE_9 = [
- {"reader": ITKReader()},
+ {"reader": ITKReader() if has_itk else "itkreader"},
["test_image.nii.gz", "test_image2.nii.gz", "test_image3.nii.gz"],
(3, 128, 128, 128),
]
-TEST_CASE_10 = [{"reader": ITKReader(pixel_type=itk.UC)}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)]
+TEST_CASE_10 = [
+ {"reader": ITKReader(pixel_type=itk_uc) if has_itk else "itkreader"},
+ "tests/testing_data/CT_DICOM",
+ (16, 16, 4),
+ (16, 16, 4),
+]
-TEST_CASE_11 = [{"reader": "ITKReader", "pixel_type": itk.UC}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)]
+TEST_CASE_11 = [{"reader": "ITKReader", "pixel_type": itk_uc}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)]
TEST_CASE_12 = [
- {"reader": "ITKReader", "pixel_type": itk.UC, "reverse_indexing": True},
+ {"reader": "ITKReader", "pixel_type": itk_uc, "reverse_indexing": True},
"tests/testing_data/CT_DICOM",
(16, 16, 4),
(4, 16, 16),
@@ -125,14 +133,14 @@ def get_data(self, _obj):
TEST_CASE_19 = [{"reader": PydicomReader()}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)]
TEST_CASE_20 = [
- {"reader": "PydicomReader", "ensure_channel_first": True},
+ {"reader": "PydicomReader", "ensure_channel_first": True, "force": True},
"tests/testing_data/CT_DICOM",
(16, 16, 4),
(1, 16, 16, 4),
]
TEST_CASE_21 = [
- {"reader": "PydicomReader", "affine_lps_to_ras": True, "defer_size": "2 MB"},
+ {"reader": "PydicomReader", "affine_lps_to_ras": True, "defer_size": "2 MB", "force": True},
"tests/testing_data/CT_DICOM",
(16, 16, 4),
(16, 16, 4),
@@ -141,13 +149,13 @@ def get_data(self, _obj):
# test reader consistency between PydicomReader and ITKReader on dicom data
TEST_CASE_22 = ["tests/testing_data/CT_DICOM"]
-
TESTS_META = []
for track_meta in (False, True):
TESTS_META.append([{}, (128, 128, 128), track_meta])
TESTS_META.append([{"reader": "ITKReader", "fallback_only": False}, (128, 128, 128), track_meta])
+@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadImage(unittest.TestCase):
@parameterized.expand(
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_3_1, TEST_CASE_4, TEST_CASE_4_1, TEST_CASE_5]
@@ -292,7 +300,7 @@ def test_my_reader(self):
def test_itk_meta(self):
"""test metadata from a directory"""
- out = LoadImage(image_only=True, reader="ITKReader", pixel_type=itk.UC, series_meta=True)(
+ out = LoadImage(image_only=True, reader="ITKReader", pixel_type=itk_uc, series_meta=True)(
"tests/testing_data/CT_DICOM"
)
idx = "0008|103e"
@@ -307,7 +315,7 @@ def test_channel_dim(self, input_param, filename, expected_shape):
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, filename)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename)
- result = LoadImage(image_only=True, **input_param)(filename)
+ result = LoadImage(image_only=True, **input_param)(filename) # with itk, meta has 'qto_xyz': itkMatrixF44
self.assertTupleEqual(
result.shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape
@@ -315,6 +323,7 @@ def test_channel_dim(self, input_param, filename, expected_shape):
self.assertEqual(result.meta["original_channel_dim"], input_param["channel_dim"])
+@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadImageMeta(unittest.TestCase):
@classmethod
def setUpClass(cls):
diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py
index cd8b476a58b..8210d2f0d1d 100644
--- a/tests/test_load_imaged.py
+++ b/tests/test_load_imaged.py
@@ -15,7 +15,6 @@
import unittest
from pathlib import Path
-import itk
import nibabel as nib
import numpy as np
import torch
@@ -26,8 +25,11 @@
from monai.data.meta_tensor import MetaTensor
from monai.transforms import Compose, EnsureChannelFirstD, FromMetaTensord, LoadImaged, SaveImageD
from monai.transforms.meta_utility.dictionary import ToMetaTensord
+from monai.utils import optional_import
from tests.utils import assert_allclose
+itk, has_itk = optional_import("itk", allow_namespace_pkg=True)
+
KEYS = ["image", "label", "extra"]
TEST_CASE_1 = [{"keys": KEYS}, (128, 128, 128)]
@@ -40,6 +42,7 @@
TESTS_META.append([{"keys": KEYS, "reader": "ITKReader", "fallback_only": False}, (128, 128, 128), track_meta])
+@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadImaged(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_shape(self, input_param, expected_shape):
@@ -87,6 +90,7 @@ def test_no_file(self):
LoadImaged(keys="img", reader="nibabelreader", image_only=True)({"img": "unknown"})
+@unittest.skipUnless(has_itk, "itk not installed")
class TestConsistency(unittest.TestCase):
def _cmp(self, filename, ch_shape, reader_1, reader_2, outname, ext):
data_dict = {"img": filename}
@@ -147,6 +151,7 @@ def test_png(self):
self._cmp(filename, (3, 224, 256), "itkreader", "nibabelreader", output_name, ".png")
+@unittest.skipUnless(has_itk, "itk not installed")
class TestLoadImagedMeta(unittest.TestCase):
@classmethod
def setUpClass(cls):
diff --git a/tests/test_loader_semaphore.py b/tests/test_loader_semaphore.py
index bbb2d4eef67..85cf5593f8e 100644
--- a/tests/test_loader_semaphore.py
+++ b/tests/test_loader_semaphore.py
@@ -8,7 +8,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
"""this test should not generate errors or
UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores"""
import multiprocessing as mp
diff --git a/tests/test_local_normalized_cross_correlation_loss.py b/tests/test_local_normalized_cross_correlation_loss.py
index 8070c27f90d..e6052824a9c 100644
--- a/tests/test_local_normalized_cross_correlation_loss.py
+++ b/tests/test_local_normalized_cross_correlation_loss.py
@@ -21,7 +21,7 @@
TEST_CASES = [
[
- {"ndim": 1, "kernel_type": "rectangular", "reduction": "sum"},
+ {"spatial_dims": 1, "kernel_type": "rectangular", "reduction": "sum"},
{
"pred": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),
"target": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),
@@ -29,7 +29,7 @@
-1.0 * 3,
],
[
- {"ndim": 1, "kernel_type": "rectangular"},
+ {"spatial_dims": 1, "kernel_type": "rectangular"},
{
"pred": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),
"target": torch.arange(0, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),
@@ -37,7 +37,15 @@
-1.0,
],
[
- {"ndim": 2, "kernel_type": "rectangular"},
+ {"spatial_dims": 1, "kernel_type": "triangular", "smooth_dr": 0.1},
+ {
+ "pred": torch.zeros(1, 2, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),
+ "target": torch.zeros(1, 2, 3).reshape(1, 1, -1).to(dtype=torch.float, device=device),
+ },
+ 0.0,
+ ],
+ [
+ {"spatial_dims": 2, "kernel_type": "rectangular"},
{
"pred": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(dtype=torch.float, device=device),
"target": torch.arange(0, 3).reshape(1, 1, -1, 1).expand(1, 1, 3, 3).to(dtype=torch.float, device=device),
@@ -45,7 +53,7 @@
-1.0,
],
[
- {"ndim": 3, "kernel_type": "rectangular"},
+ {"spatial_dims": 3, "kernel_type": "rectangular"},
{
"pred": torch.arange(0, 3)
.reshape(1, 1, -1, 1, 1)
@@ -59,7 +67,7 @@
-1.0,
],
[
- {"ndim": 3, "kernel_type": "rectangular"},
+ {"spatial_dims": 3, "kernel_type": "rectangular"},
{
"pred": torch.arange(0, 3)
.reshape(1, 1, -1, 1, 1)
@@ -74,7 +82,7 @@
-0.95801723,
],
[
- {"ndim": 3, "kernel_type": "triangular", "kernel_size": 5},
+ {"spatial_dims": 3, "kernel_type": "triangular", "kernel_size": 5},
{
"pred": torch.arange(0, 5)
.reshape(1, 1, -1, 1, 1)
@@ -89,7 +97,7 @@
-0.918672,
],
[
- {"ndim": 3, "kernel_type": "gaussian"},
+ {"spatial_dims": 3, "kernel_type": "gaussian"},
{
"pred": torch.arange(0, 3)
.reshape(1, 1, -1, 1, 1)
@@ -113,8 +121,8 @@ def test_shape(self, input_param, input_data, expected_val):
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
def test_ill_shape(self):
- loss = LocalNormalizedCrossCorrelationLoss(ndim=3)
- # ndim unmatch
+ loss = LocalNormalizedCrossCorrelationLoss(spatial_dims=3)
+ # spatial_dims unmatch
with self.assertRaisesRegex(ValueError, ""):
loss.forward(
torch.ones((1, 3, 3, 3), dtype=torch.float, device=device),
@@ -147,6 +155,5 @@ def test_ill_opts(self):
# loss = LocalNormalizedCrossCorrelationLoss(**input_param)
# test_script_save(loss, input_data["pred"], input_data["target"])
-
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_localnet.py b/tests/test_localnet.py
index 1a288fb4476..9296edab99d 100644
--- a/tests/test_localnet.py
+++ b/tests/test_localnet.py
@@ -20,7 +20,6 @@
device = "cuda" if torch.cuda.is_available() else "cpu"
-
TEST_CASE_LOCALNET_2D = [
[
{
diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py
index d036595241c..aed7976feb4 100644
--- a/tests/test_lr_finder.py
+++ b/tests/test_lr_finder.py
@@ -24,6 +24,7 @@
from monai.optimizers import LearningRateFinder
from monai.transforms import AddChanneld, Compose, LoadImaged, ScaleIntensityd, ToTensord
from monai.utils import optional_import, set_determinism
+from monai.utils.misc import MONAIEnvVars
from tests.utils import skip_if_downloading_fails
if TYPE_CHECKING:
@@ -47,7 +48,7 @@
class TestLRFinder(unittest.TestCase):
def setUp(self):
- self.root_dir = os.environ.get("MONAI_DATA_DIRECTORY")
+ self.root_dir = MONAIEnvVars.data_dir()
if not self.root_dir:
self.root_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py
index a3e1ea9dd64..44f4c50c0f6 100644
--- a/tests/test_lr_scheduler.py
+++ b/tests/test_lr_scheduler.py
@@ -28,7 +28,11 @@ def forward(self, x):
TEST_CASE_LRSCHEDULER = [
- [{"warmup_steps": 2, "t_total": 10}, [0.000, 0.500, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038]]
+ [{"warmup_steps": 2, "t_total": 10}, [0.000, 0.500, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038]],
+ [
+ {"warmup_steps": 2, "t_total": 10, "warmup_multiplier": 0.1},
+ [0.1, 0.55, 1.00, 0.962, 0.854, 0.691, 0.500, 0.309, 0.146, 0.038],
+ ],
]
@@ -47,6 +51,13 @@ def test_shape(self, input_param, expected_lr):
for a, b in zip(lrs_1, expected_lr):
self.assertEqual(a, b, msg=f"LR is wrong ! expected {b}, got {a}")
+ def test_error(self):
+ """Should fail because warmup_multiplier is outside 0..1"""
+ net = SchedulerTestNet()
+ optimizer = torch.optim.Adam(net.parameters(), lr=1.0)
+ with self.assertRaises(ValueError):
+ WarmupCosineSchedule(optimizer, warmup_steps=2, t_total=10, warmup_multiplier=-1)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_map_label_value.py b/tests/test_map_label_value.py
index 0416858a74e..ef08f7eae30 100644
--- a/tests/test_map_label_value.py
+++ b/tests/test_map_label_value.py
@@ -16,7 +16,7 @@
from parameterized import parameterized
from monai.transforms import MapLabelValue
-from tests.utils import TEST_NDARRAYS
+from tests.utils import TEST_NDARRAYS, assert_allclose
TESTS = []
for p in TEST_NDARRAYS:
@@ -70,7 +70,7 @@ class TestMapLabelValue(unittest.TestCase):
def test_shape(self, input_param, input_data, expected_value):
result = MapLabelValue(**input_param)(input_data)
if isinstance(expected_value, torch.Tensor):
- torch.testing.assert_allclose(result, expected_value)
+ assert_allclose(result, expected_value)
else:
np.testing.assert_equal(result, expected_value)
self.assertTupleEqual(result.shape, expected_value.shape)
diff --git a/tests/test_masked_inference_wsi_dataset.py b/tests/test_masked_inference_wsi_dataset.py
index c29b95a2d8a..c424edd8970 100644
--- a/tests/test_masked_inference_wsi_dataset.py
+++ b/tests/test_masked_inference_wsi_dataset.py
@@ -10,6 +10,7 @@
# limitations under the License.
import os
+import tempfile
import unittest
from unittest import skipUnless
@@ -30,23 +31,21 @@
FILE_NAME = f"temp_{base_name}"
FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", FILE_NAME + extension)
-MASK1 = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tissue_mask1.npy")
-MASK2 = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tissue_mask2.npy")
-MASK4 = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tissue_mask4.npy")
+MASK1, MASK2, MASK4 = "mask1.npy", "mask2.npy", "mask4.npy"
HEIGHT = 32914
WIDTH = 46000
-def prepare_data():
+def prepare_data(*masks):
mask = np.zeros((HEIGHT // 2, WIDTH // 2))
mask[100, 100] = 1
- np.save(MASK1, mask)
+ np.save(masks[0], mask)
mask[100, 101] = 1
- np.save(MASK2, mask)
+ np.save(masks[1], mask)
mask[100:102, 100:102] = 1
- np.save(MASK4, mask)
+ np.save(masks[2], mask)
TEST_CASE_0 = [
@@ -134,7 +133,6 @@ def prepare_data():
],
]
-
TEST_CASE_OPENSLIDE_0 = [
{"data": [{"image": FILE_PATH, "mask": MASK1}], "patch_size": 1, "image_reader_name": "OpenSlide"},
[{"image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), "name": FILE_NAME, "mask_location": [100, 100]}],
@@ -157,17 +155,24 @@ def prepare_data():
]
+@skip_if_quick
class TestMaskedInferenceWSIDataset(unittest.TestCase):
def setUp(self):
- prepare_data()
+ self.base_dir = tempfile.TemporaryDirectory()
+ prepare_data(*[os.path.join(self.base_dir.name, m) for m in [MASK1, MASK2, MASK4]])
hash_type = testing_data_config("images", FILE_KEY, "hash_type")
hash_val = testing_data_config("images", FILE_KEY, "hash_val")
download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)
+ def tearDown(self):
+ self.base_dir.cleanup()
+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
@skipUnless(has_cim, "Requires CuCIM")
@skip_if_quick
def test_read_patches_cucim(self, input_parameters, expected):
+ for m in input_parameters["data"]:
+ m["mask"] = os.path.join(self.base_dir.name, m["mask"])
dataset = MaskedInferenceWSIDataset(**input_parameters)
self.compare_samples_expected(dataset, expected)
@@ -175,15 +180,17 @@ def test_read_patches_cucim(self, input_parameters, expected):
@skipUnless(has_osl, "Requires OpenSlide")
@skip_if_quick
def test_read_patches_openslide(self, input_parameters, expected):
+ for m in input_parameters["data"]:
+ m["mask"] = os.path.join(self.base_dir.name, m["mask"])
dataset = MaskedInferenceWSIDataset(**input_parameters)
self.compare_samples_expected(dataset, expected)
def compare_samples_expected(self, dataset, expected):
- for i in range(len(dataset)):
- self.assertTupleEqual(dataset[i][0]["image"].shape, expected[i]["image"].shape)
- self.assertIsNone(assert_array_equal(dataset[i][0]["image"], expected[i]["image"]))
- self.assertEqual(dataset[i][0]["name"], expected[i]["name"])
- self.assertListEqual(dataset[i][0]["mask_location"], expected[i]["mask_location"])
+ for i, item in enumerate(dataset):
+ self.assertTupleEqual(item[0]["image"].shape, expected[i]["image"].shape)
+ self.assertIsNone(assert_array_equal(item[0]["image"], expected[i]["image"]))
+ self.assertEqual(item[0]["name"], expected[i]["name"])
+ self.assertListEqual(item[0]["mask_location"], expected[i]["mask_location"])
if __name__ == "__main__":
diff --git a/tests/test_masked_patch_wsi_dataset.py b/tests/test_masked_patch_wsi_dataset.py
index 797a39ee09b..9783c2d7cf6 100644
--- a/tests/test_masked_patch_wsi_dataset.py
+++ b/tests/test_masked_patch_wsi_dataset.py
@@ -30,7 +30,6 @@
_, has_codec = optional_import("imagecodecs")
has_tiff = has_tiff and has_codec
-
FILE_KEY = "wsi_img"
FILE_URL = testing_data_config("images", FILE_KEY, "url")
base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff"
@@ -87,13 +86,13 @@ def test_gen_patches(self, input_parameters, expected):
self.assertTrue(d1[ProbMapKeys.NAME] == os.path.basename(d2["image"]))
for i, sample in enumerate(dataset):
- self.assertEqual(sample["metadata"][WSIPatchKeys.LEVEL], expected["patch_level"])
- assert_array_equal(sample["metadata"][WSIPatchKeys.SIZE], expected["patch_size"])
+ self.assertEqual(sample["image"].meta[WSIPatchKeys.LEVEL], expected["patch_level"])
+ assert_array_equal(sample["image"].meta[WSIPatchKeys.SIZE], expected["patch_size"])
assert_array_equal(sample["image"].shape[1:], expected["patch_size"])
- self.assertTrue(sample["metadata"][WSIPatchKeys.LOCATION][0] >= 0)
- self.assertTrue(sample["metadata"][WSIPatchKeys.LOCATION][0] < expected["wsi_size"][0])
- self.assertTrue(sample["metadata"][WSIPatchKeys.LOCATION][1] >= 0)
- self.assertTrue(sample["metadata"][WSIPatchKeys.LOCATION][1] < expected["wsi_size"][1])
+ self.assertTrue(sample["image"].meta[WSIPatchKeys.LOCATION][0] >= 0)
+ self.assertTrue(sample["image"].meta[WSIPatchKeys.LOCATION][0] < expected["wsi_size"][0])
+ self.assertTrue(sample["image"].meta[WSIPatchKeys.LOCATION][1] >= 0)
+ self.assertTrue(sample["image"].meta[WSIPatchKeys.LOCATION][1] < expected["wsi_size"][1])
if i > 10:
break
diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py
index b14f6f01d33..060170e3bfa 100644
--- a/tests/test_mean_ensemble.py
+++ b/tests/test_mean_ensemble.py
@@ -68,7 +68,7 @@ def test_cuda_value(self):
img = img.to(torch.device("cuda:0"))
expected_value = expected_value.to(torch.device("cuda:0"))
result = MeanEnsemble(torch.tensor([[[1, 3]], [[3, 1]]]))(img)
- torch.testing.assert_allclose(result, expected_value)
+ assert_allclose(result, expected_value)
if __name__ == "__main__":
diff --git a/tests/test_mean_ensembled.py b/tests/test_mean_ensembled.py
index b5e1569d653..f6d6286d359 100644
--- a/tests/test_mean_ensembled.py
+++ b/tests/test_mean_ensembled.py
@@ -73,7 +73,7 @@ class TestMeanEnsembled(unittest.TestCase):
@parameterized.expand(TESTS)
def test_value(self, input_param, data, expected_value):
result = MeanEnsembled(**input_param)(data)
- torch.testing.assert_allclose(result["output"], expected_value)
+ assert_allclose(result["output"], expected_value)
def test_cuda_value(self):
img = torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2])
diff --git a/tests/test_median_filter.py b/tests/test_median_filter.py
new file mode 100644
index 00000000000..cf6286bdfe2
--- /dev/null
+++ b/tests/test_median_filter.py
@@ -0,0 +1,55 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+
+from monai.networks.layers import MedianFilter
+
+
+class MedianFilterTestCase(unittest.TestCase):
+ def test_3d_big(self):
+ a = torch.ones(1, 1, 2, 3, 5)
+ g = MedianFilter([1, 2, 4]).to(torch.device("cpu:0"))
+
+ expected = a.numpy()
+ out = g(a).cpu().numpy()
+ np.testing.assert_allclose(out, expected, rtol=1e-5)
+
+ def test_3d(self):
+ a = torch.ones(1, 1, 4, 3, 4)
+ g = MedianFilter(1).to(torch.device("cpu:0"))
+
+ expected = a.numpy()
+ out = g(a).cpu().numpy()
+ np.testing.assert_allclose(out, expected, rtol=1e-5)
+
+ def test_3d_radii(self):
+ a = torch.ones(1, 1, 4, 3, 2)
+ g = MedianFilter([3, 2, 1]).to(torch.device("cpu:0"))
+
+ expected = a.numpy()
+ out = g(a).cpu().numpy()
+ np.testing.assert_allclose(out, expected, rtol=1e-5)
+ if torch.cuda.is_available():
+ g = MedianFilter([3, 2, 1]).to(torch.device("cuda:0"))
+ np.testing.assert_allclose(g(a.cuda()).cpu().numpy(), expected, rtol=1e-2)
+
+ def test_wrong_args(self):
+ with self.assertRaisesRegex(ValueError, ""):
+ MedianFilter([3, 2]).to(torch.device("cpu:0"))
+ MedianFilter([3, 2, 1]).to(torch.device("cpu:0")) # test init
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_median_smooth.py b/tests/test_median_smooth.py
new file mode 100644
index 00000000000..87d29482ddc
--- /dev/null
+++ b/tests/test_median_smooth.py
@@ -0,0 +1,39 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from parameterized import parameterized
+
+from monai.transforms import MedianSmooth
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+TESTS = []
+
+for p in TEST_NDARRAYS:
+ TESTS.append(
+ [
+ {"radius": 1},
+ p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
+ p([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
+ ]
+ )
+
+
+class TestMedianSmooth(unittest.TestCase):
+ @parameterized.expand(TESTS)
+ def test_value(self, argments, image, expected_data):
+ result = MedianSmooth(**argments)(image)
+ assert_allclose(result, expected_data, atol=1e-4, rtol=1e-4, type_test="tensor")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_median_smoothd.py b/tests/test_median_smoothd.py
new file mode 100644
index 00000000000..811e833a90d
--- /dev/null
+++ b/tests/test_median_smoothd.py
@@ -0,0 +1,63 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import MedianSmoothd
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+TESTS = []
+for p in TEST_NDARRAYS[0:1]:
+ TESTS.append(
+ [
+ {"keys": "img", "radius": [0, 1]},
+ {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]))},
+ np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]]]),
+ ]
+ )
+
+ TESTS.append(
+ [
+ {"keys": "img", "radius": 1},
+ {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},
+ np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
+ ]
+ )
+
+ TESTS.append(
+ [
+ {"keys": "img", "radius": [1, 1]},
+ {"img": p(np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]))},
+ np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
+ ]
+ )
+
+ TESTS.append(
+ [
+ {"keys": "img", "radius": [1, 1, 1]},
+ {"img": p(np.array([[[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]]))},
+ np.array([[[[2, 2, 2], [3, 3, 3], [3, 3, 3]], [[4, 4, 4], [4, 4, 4], [5, 5, 5]]]]),
+ ]
+ )
+
+
+class TestMedianSmoothd(unittest.TestCase):
+ @parameterized.expand(TESTS)
+ def test_value(self, arguments, image, expected_data):
+ result = MedianSmoothd(**arguments)(image)
+ assert_allclose(result["img"], expected_data, rtol=1e-4, type_test="tensor")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index ccf2753ec3a..20d25ef61c9 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -422,21 +422,7 @@ def test_decollate(self, dtype):
def test_str(self):
t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"})
- s1 = str(t)
- s2 = t.__repr__()
- expected_out = (
- "tensor([1.])\n"
- + "Metadata\n"
- + "\tfname: filename\n"
- + "\taffine: 1\n"
- + "\tspace: RAS\n"
- + "\n"
- + "Applied operations\n"
- + "[]\n"
- + "Is batch?: False"
- )
- for s in (s1, s2):
- self.assertEqual(s, expected_out)
+ self.assertEqual(str(t), "tensor([1.])")
def test_astype(self):
t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"})
@@ -509,11 +495,20 @@ def test_construct_with_pre_applied_transforms(self):
m = MetaTensor(im, applied_operations=data["im"].applied_operations)
self.assertEqual(len(m.applied_operations), len(tr.transforms))
+ def test_pending_ops(self):
+ m, _ = self.get_im()
+ self.assertEqual(m.pending_operations, [])
+ self.assertEqual(m.peek_pending_shape(), (10, 8))
+ self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
+ m.push_pending_operation({})
+ self.assertEqual(m.peek_pending_shape(), (10, 8))
+ self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
+
@parameterized.expand(TESTS)
def test_multiprocessing(self, device=None, dtype=None):
"""multiprocessing sharing with 'device' and 'dtype'"""
buf = io.BytesIO()
- t = MetaTensor([0.0, 0.0], device=device, dtype=dtype)
+ t = MetaTensor([0, 0] if dtype in (torch.int32, torch.int64) else [0.0, 0.0], device=device, dtype=dtype)
t.is_batch = True
if t.is_cuda:
with self.assertRaises(NotImplementedError):
@@ -532,7 +527,9 @@ def test_array_function(self, device="cpu", dtype=float):
assert_allclose(np.sum(a), np.sum(b))
assert_allclose(np.sum(a, axis=1), np.sum(b, axis=1))
assert_allclose(np.linalg.qr(a), np.linalg.qr(b))
- c = MetaTensor([1.0, 2.0, 3.0], device=device, dtype=dtype)
+ c = MetaTensor(
+ [1, 2, 3] if dtype in (torch.int32, torch.int64) else [1.0, 2.0, 3.0], device=device, dtype=dtype
+ )
assert_allclose(np.argwhere(c == 1.0).astype(int).tolist(), [[0]])
assert_allclose(np.concatenate([c, c]), np.asarray([1.0, 2.0, 3.0, 1.0, 2.0, 3.0]))
if pytorch_after(1, 8, 1):
@@ -544,7 +541,7 @@ def test_array_function(self, device="cpu", dtype=float):
@parameterized.expand(TESTS)
def test_numpy(self, device=None, dtype=None):
"""device, dtype"""
- t = MetaTensor([0.0], device=device, dtype=dtype)
+ t = MetaTensor([0 if dtype in (torch.int32, torch.int64) else 0.0], device=device, dtype=dtype)
self.assertIsInstance(t, MetaTensor)
assert_allclose(t.array, np.asarray([0.0]))
t.array = np.asarray([1.0])
@@ -554,7 +551,7 @@ def test_numpy(self, device=None, dtype=None):
self.check_meta(t, MetaTensor([2.0]))
assert_allclose(t.as_tensor(), torch.as_tensor([2.0]))
if not t.is_cuda:
- t.array[0] = torch.as_tensor(3.0, device=device, dtype=dtype)
+ t.array[0] = torch.as_tensor(3 if dtype in (torch.int32, torch.int64) else 3.0, device=device, dtype=dtype)
self.check_meta(t, MetaTensor([3.0]))
assert_allclose(t.as_tensor(), torch.as_tensor([3.0]))
diff --git a/tests/test_milmodel.py b/tests/test_milmodel.py
index ad04e96c600..2d58af4a2b3 100644
--- a/tests/test_milmodel.py
+++ b/tests/test_milmodel.py
@@ -17,13 +17,12 @@
from monai.networks import eval_mode
from monai.networks.nets import MILModel
from monai.utils.module import optional_import
-from tests.utils import test_script_save
+from tests.utils import skip_if_downloading_fails, test_script_save
models, _ = optional_import("torchvision.models")
device = "cuda" if torch.cuda.is_available() else "cpu"
-
TEST_CASE_MILMODEL = []
for num_classes in [1, 5]:
for mil_mode in ["mean", "max", "att", "att_trans", "att_trans_pyramid"]:
@@ -34,7 +33,6 @@
]
TEST_CASE_MILMODEL.append(test_case)
-
for trans_blocks in [1, 3]:
test_case = [
{"num_classes": 5, "pretrained": False, "trans_blocks": trans_blocks, "trans_dropout": 0.5},
@@ -65,7 +63,8 @@
class TestMilModel(unittest.TestCase):
@parameterized.expand(TEST_CASE_MILMODEL)
def test_shape(self, input_param, input_shape, expected_shape):
- net = MILModel(**input_param).to(device)
+ with skip_if_downloading_fails():
+ net = MILModel(**input_param).to(device)
with eval_mode(net):
result = net(torch.randn(input_shape, dtype=torch.float).to(device))
self.assertEqual(result.shape, expected_shape)
diff --git a/tests/test_monai_env_vars.py b/tests/test_monai_env_vars.py
new file mode 100644
index 00000000000..663dcdd98d9
--- /dev/null
+++ b/tests/test_monai_env_vars.py
@@ -0,0 +1,41 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+from monai.utils.misc import MONAIEnvVars
+
+
+class TestMONAIEnvVars(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super(__class__, cls).setUpClass()
+ cls.orig_value = os.environ.get("MONAI_DEBUG")
+
+ @classmethod
+ def tearDownClass(cls):
+ if cls.orig_value is not None:
+ os.environ["MONAI_DEBUG"] = cls.orig_value
+ else:
+ os.environ.pop("MONAI_DEBUG")
+ print("MONAI debug value:", os.environ.get("MONAI_DEBUG"))
+ super(__class__, cls).tearDownClass()
+
+ def test_monai_env_vars(self):
+ for debug in (False, True):
+ os.environ["MONAI_DEBUG"] = str(debug)
+ self.assertEqual(os.environ.get("MONAI_DEBUG"), str(debug))
+ self.assertEqual(MONAIEnvVars.debug(), debug)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_net_adapter.py b/tests/test_net_adapter.py
index 39201fb6007..c3abc8d142b 100644
--- a/tests/test_net_adapter.py
+++ b/tests/test_net_adapter.py
@@ -16,6 +16,7 @@
from monai.networks import eval_mode
from monai.networks.nets import NetAdapter, resnet18
+from tests.utils import test_script_save
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -50,6 +51,16 @@ def test_shape(self, input_param, input_shape, expected_shape):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)
+ @parameterized.expand([TEST_CASE_0])
+ def test_script(self, input_param, input_shape, expected_shape):
+ spatial_dims = input_param["dim"]
+ stride = (1, 2, 2)[:spatial_dims]
+ model = resnet18(spatial_dims=spatial_dims, conv1_t_stride=stride)
+ input_param["model"] = model
+ net = NetAdapter(**input_param).to("cpu")
+ test_data = torch.randn(input_shape).to("cpu")
+ test_script_save(net, test_data)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py
index 419e1202d00..327f0cfbf01 100644
--- a/tests/test_network_consistency.py
+++ b/tests/test_network_consistency.py
@@ -21,6 +21,7 @@
import monai.networks.nets as nets
from monai.utils import set_determinism
+from tests.utils import assert_allclose
extra_test_data_dir = os.environ.get("MONAI_EXTRA_TEST_DATA")
@@ -76,7 +77,7 @@ def check_output_consistency(self, actual, expected):
for a, e in zip(actual, expected):
self.check_output_consistency(a, e)
else:
- torch.testing.assert_allclose(actual, expected)
+ assert_allclose(actual, expected, rtol=5e-2, atol=1e-3)
if __name__ == "__main__":
diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py
index fae53394c37..7da50617d94 100644
--- a/tests/test_nifti_rw.py
+++ b/tests/test_nifti_rw.py
@@ -30,14 +30,6 @@
[[-5.3, 0.0, 0.0, 102.01], [0.0, 0.52, 2.17, -7.50], [-0.0, 1.98, -0.26, -23.12], [0.0, 0.0, 0.0, 1.0]]
)
)
- # TESTS.append(
- # [
- # TEST_IMAGE,
- # TEST_AFFINE,
- # dict(reader="NibabelReader", image_only=False, as_closest_canonical=True),
- # np.arange(24).reshape((2, 4, 3)),
- # ]
- # )
TESTS.append(
[
TEST_IMAGE,
@@ -165,8 +157,8 @@ def test_write_2d(self):
writer_obj.set_metadata({"affine": np.diag([1, 1, 1]), "original_affine": np.diag([1.4, 1, 1])})
writer_obj.write(image_name, verbose=True)
out = nib.load(image_name)
- np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]])
- np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]))
+ np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]], atol=1e-4, rtol=1e-4)
+ np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]), atol=1e-4, rtol=1e-4)
image_name = os.path.join(out_dir, "test1.nii.gz")
img = np.arange(5).reshape((1, 5))
@@ -176,8 +168,8 @@ def test_write_2d(self):
)
writer_obj.write(image_name, verbose=True)
out = nib.load(image_name)
- np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]])
- np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1]))
+ np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]], atol=1e-4, rtol=1e-4)
+ np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1]), atol=1e-4, rtol=1e-4)
def test_write_3d(self):
with tempfile.TemporaryDirectory() as out_dir:
@@ -189,8 +181,8 @@ def test_write_3d(self):
writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 1]), "original_affine": np.diag([1.4, 1, 1, 1])})
writer_obj.write(image_name, verbose=True)
out = nib.load(image_name)
- np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]])
- np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]))
+ np.testing.assert_allclose(out.get_fdata(), [[[0, 1, 2], [3, 4, 5]]], atol=1e-4, rtol=1e-4)
+ np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]), atol=1e-4, rtol=1e-4)
image_name = os.path.join(out_dir, "test1.nii.gz")
img = p(np.arange(5).reshape((1, 1, 5)))
@@ -200,8 +192,8 @@ def test_write_3d(self):
)
writer_obj.write(image_name, verbose=True)
out = nib.load(image_name)
- np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]])
- np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]))
+ np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]], atol=1e-4, rtol=1e-4)
+ np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4)
def test_write_4d(self):
with tempfile.TemporaryDirectory() as out_dir:
@@ -213,8 +205,8 @@ def test_write_4d(self):
writer_obj.set_metadata({"affine": np.diag([1.4, 1, 1, 1]), "original_affine": np.diag([1, 1.4, 1, 1])})
writer_obj.write(image_name, verbose=True)
out = nib.load(image_name)
- np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]])
- np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1]))
+ np.testing.assert_allclose(out.get_fdata(), [[[[0, 1], [2, 3], [4, 5]]]], atol=1e-4, rtol=1e-4)
+ np.testing.assert_allclose(out.affine, np.diag([1, 1.4, 1, 1]), atol=1e-4, rtol=1e-4)
image_name = os.path.join(out_dir, "test1.nii.gz")
img = p(np.arange(5).reshape((1, 1, 5, 1)))
@@ -224,8 +216,8 @@ def test_write_4d(self):
)
writer_obj.write(image_name, verbose=True)
out = nib.load(image_name)
- np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]])
- np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]))
+ np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]], atol=1e-4, rtol=1e-4)
+ np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4)
def test_write_5d(self):
with tempfile.TemporaryDirectory() as out_dir:
@@ -240,8 +232,10 @@ def test_write_5d(self):
np.testing.assert_allclose(
out.get_fdata(),
np.array([[[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]], [[8.0, 9.0], [10.0, 11.0]]]]]),
+ atol=1e-4,
+ rtol=1e-4,
)
- np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]))
+ np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]), atol=1e-4, rtol=1e-4)
image_name = os.path.join(out_dir, "test1.nii.gz")
img = p(np.arange(10).reshape((1, 1, 5, 1, 2)))
@@ -249,8 +243,10 @@ def test_write_5d(self):
writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 3]), "original_affine": np.diag([1.4, 2.0, 2, 3])})
writer_obj.write(image_name, verbose=True)
out = nib.load(image_name)
- np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 2.0]], [[4.0, 5.0]], [[7.0, 9.0]]]]]))
- np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]))
+ np.testing.assert_allclose(
+ out.get_fdata(), np.array([[[[[0.0, 2.0]], [[4.0, 5.0]], [[7.0, 9.0]]]]]), atol=1e-4, rtol=1e-4
+ )
+ np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4)
if __name__ == "__main__":
diff --git a/tests/test_numpy_reader.py b/tests/test_numpy_reader.py
index bb7686f67d9..d220a67c920 100644
--- a/tests/test_numpy_reader.py
+++ b/tests/test_numpy_reader.py
@@ -15,10 +15,10 @@
import unittest
import numpy as np
-import torch
from monai.data import DataLoader, Dataset, NumpyReader
from monai.transforms import LoadImaged
+from tests.utils import assert_allclose
class TestNumpyReader(unittest.TestCase):
@@ -110,7 +110,7 @@ def test_dataloader(self):
)
for d in loader:
for c in d["image"]:
- torch.testing.assert_allclose(c, test_data)
+ assert_allclose(c, test_data, type_test=False)
def test_channel_dim(self):
test_data = np.random.randint(0, 256, size=[3, 4, 5, 2])
diff --git a/tests/test_nvtx_decorator.py b/tests/test_nvtx_decorator.py
index 9932b678c9e..7dd2dd81b5d 100644
--- a/tests/test_nvtx_decorator.py
+++ b/tests/test_nvtx_decorator.py
@@ -39,7 +39,6 @@
_, has_tvt = optional_import("torchvision.transforms")
_, has_cut = optional_import("cucim.core.operations.expose.transform")
-
TEST_CASE_ARRAY_0 = [np.random.randn(3, 3)]
TEST_CASE_ARRAY_1 = [np.random.randn(3, 10, 10)]
diff --git a/tests/test_nvtx_transform.py b/tests/test_nvtx_transform.py
index 01a069ed8a1..fd784f6d323 100644
--- a/tests/test_nvtx_transform.py
+++ b/tests/test_nvtx_transform.py
@@ -34,7 +34,6 @@
_, has_nvtx = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")
-
TEST_CASE_ARRAY_0 = [np.random.randn(3, 3)]
TEST_CASE_ARRAY_1 = [np.random.randn(3, 10, 10)]
TEST_CASE_DICT_0 = [{"image": np.random.randn(3, 3)}]
diff --git a/tests/test_occlusion_sensitivity.py b/tests/test_occlusion_sensitivity.py
index f258dfc5570..02e34704f17 100644
--- a/tests/test_occlusion_sensitivity.py
+++ b/tests/test_occlusion_sensitivity.py
@@ -10,6 +10,7 @@
# limitations under the License.
import unittest
+from typing import Any, List
import torch
from parameterized import parameterized
@@ -17,6 +18,14 @@
from monai.networks.nets import DenseNet, DenseNet121
from monai.visualize import OcclusionSensitivity
+
+class DenseNetAdjoint(DenseNet121):
+ def __call__(self, x, adjoint_info):
+ if adjoint_info != 42:
+ raise ValueError
+ return super().__call__(x)
+
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
out_channels_2d = 4
out_channels_3d = 3
@@ -25,44 +34,75 @@
model_3d = DenseNet(
spatial_dims=3, in_channels=1, out_channels=out_channels_3d, init_features=2, growth_rate=2, block_config=(6,)
).to(device)
+model_2d_adjoint = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=out_channels_2d).to(device)
model_2d.eval()
model_2d_2c.eval()
model_3d.eval()
+model_2d_adjoint.eval()
-# 2D w/ bounding box
-TEST_CASE_0 = [
- {"nn_module": model_2d},
- {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [-1, -1, 2, 40, 1, 62]},
- (1, 1, 39, 62, out_channels_2d),
- (1, 1, 39, 62),
-]
-# 3D w/ bounding box and stride
-TEST_CASE_1 = [
- {"nn_module": model_3d, "n_batch": 10, "stride": (2, 1, 2), "mask_size": (16, 15, 14)},
- {"x": torch.rand(1, 1, 6, 6, 6).to(device), "b_box": [-1, -1, 2, 3, -1, -1, -1, -1]},
- (1, 1, 2, 6, 6, out_channels_3d),
- (1, 1, 2, 6, 6),
-]
+TESTS: List[Any] = []
+TESTS_FAIL: List[Any] = []
-TEST_CASE_FAIL_0 = [ # 2D should fail, since 3 stride values given
- {"nn_module": model_2d, "n_batch": 10, "stride": (2, 2, 2)},
- {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [-1, -1, 2, 3, -1, -1]},
-]
-
-TEST_CASE_FAIL_1 = [ # 2D should fail, since stride is not a factor of image size
- {"nn_module": model_2d, "stride": 3},
- {"x": torch.rand(1, 1, 48, 64).to(device)},
-]
-TEST_MULTI_CHANNEL = [
- {"nn_module": model_2d_2c, "per_channel": False},
- {"x": torch.rand(1, 2, 48, 64).to(device)},
- (1, 1, 48, 64, out_channels_2d),
- (1, 1, 48, 64),
-]
+# 2D w/ bounding box with all modes
+for mode in ("gaussian", "mean_patch", "mean_img"):
+ TESTS.append(
+ [
+ {"nn_module": model_2d, "mode": mode},
+ {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [2, 40, 1, 62]},
+ (1, out_channels_2d, 38, 61),
+ (1, 1, 38, 61),
+ ]
+ )
+# 3D w/ bounding box
+TESTS.append(
+ [
+ {"nn_module": model_3d, "n_batch": 10, "mask_size": (16, 15, 14)},
+ {"x": torch.rand(1, 1, 64, 32, 16).to(device), "b_box": [2, 43, -1, -1, -1, -1]},
+ (1, out_channels_3d, 41, 32, 16),
+ (1, 1, 41, 32, 16),
+ ]
+)
+TESTS.append(
+ [
+ {"nn_module": model_3d, "n_batch": 10},
+ {"x": torch.rand(1, 1, 6, 7, 8).to(device), "b_box": [1, 3, -1, -1, -1, -1]},
+ (1, out_channels_3d, 2, 7, 8),
+ (1, 1, 2, 7, 8),
+ ]
+)
+TESTS.append(
+ [
+ {"nn_module": model_2d_2c},
+ {"x": torch.rand(1, 2, 48, 64).to(device)},
+ (1, out_channels_2d, 48, 64),
+ (1, 1, 48, 64),
+ ]
+)
+# 2D w/ bounding box and adjoint
+TESTS.append(
+ [
+ {"nn_module": model_2d_adjoint},
+ {"x": torch.rand(1, 1, 48, 64).to(device), "b_box": [2, 40, 1, 62], "adjoint_info": 42},
+ (1, out_channels_2d, 38, 61),
+ (1, 1, 38, 61),
+ ]
+)
+# 2D should fail: bbox makes image too small
+TESTS_FAIL.append(
+ [{"nn_module": model_2d, "n_batch": 10, "mask_size": 200}, {"x": torch.rand(1, 1, 48, 64).to(device)}, ValueError]
+)
+# 2D should fail: batch > 1
+TESTS_FAIL.append(
+ [{"nn_module": model_2d, "n_batch": 10, "mask_size": 100}, {"x": torch.rand(2, 1, 48, 64).to(device)}, ValueError]
+)
+# 2D should fail: unknown mode
+TESTS_FAIL.append(
+ [{"nn_module": model_2d, "mode": "test"}, {"x": torch.rand(1, 1, 48, 64).to(device)}, NotImplementedError]
+)
class TestComputeOcclusionSensitivity(unittest.TestCase):
- @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_MULTI_CHANNEL])
+ @parameterized.expand(TESTS)
def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expected_shape):
occ_sens = OcclusionSensitivity(**init_data)
m, most_prob = occ_sens(**call_data)
@@ -73,10 +113,10 @@ def test_shape(self, init_data, call_data, map_expected_shape, most_prob_expecte
self.assertGreaterEqual(most_prob.min(), 0)
self.assertLess(most_prob.max(), m.shape[-1])
- @parameterized.expand([TEST_CASE_FAIL_0, TEST_CASE_FAIL_1])
- def test_fail(self, init_data, call_data):
- occ_sens = OcclusionSensitivity(**init_data)
- with self.assertRaises(ValueError):
+ @parameterized.expand(TESTS_FAIL)
+ def test_fail(self, init_data, call_data, error_type):
+ with self.assertRaises(error_type):
+ occ_sens = OcclusionSensitivity(**init_data)
occ_sens(**call_data)
diff --git a/tests/test_one_of.py b/tests/test_one_of.py
index 29d13d7d0c6..2ea41c6e506 100644
--- a/tests/test_one_of.py
+++ b/tests/test_one_of.py
@@ -15,11 +15,15 @@
import numpy as np
from parameterized import parameterized
+from monai.data import MetaTensor
from monai.transforms import (
InvertibleTransform,
OneOf,
+ RandScaleIntensity,
RandScaleIntensityd,
+ RandShiftIntensity,
RandShiftIntensityd,
+ Resize,
Resized,
TraceableTransform,
Transform,
@@ -106,10 +110,10 @@ def __init__(self, keys):
KEYS = ["x", "y"]
TEST_INVERSES = [
- (OneOf((InvA(KEYS), InvB(KEYS))), True),
- (OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True),
- (OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True),
- (OneOf((NonInv(KEYS), NonInv(KEYS))), False),
+ (OneOf((InvA(KEYS), InvB(KEYS))), True, True),
+ (OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True, False),
+ (OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True, False),
+ (OneOf((NonInv(KEYS), NonInv(KEYS))), False, False),
]
@@ -136,6 +140,7 @@ def test_len_and_flatten(self):
def test_compose_flatten_does_not_affect_one_of(self):
p = Compose([A(), B(), OneOf([C(), Inv(KEYS), Compose([X(), Y()])])])
f = p.flatten()
+
# in this case the flattened transform should be the same.
def _match(a, b):
@@ -148,13 +153,17 @@ def _match(a, b):
_match(p, f)
@parameterized.expand(TEST_INVERSES)
- def test_inverse(self, transform, invertible):
- data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)}
+ def test_inverse(self, transform, invertible, use_metatensor):
+ data = {k: (i + 1) * 10.0 if not use_metatensor else MetaTensor((i + 1) * 10.0) for i, k in enumerate(KEYS)}
fwd_data = transform(data)
if invertible:
for k in KEYS:
- t = fwd_data[TraceableTransform.trace_key(k)][-1]
+ t = (
+ fwd_data[TraceableTransform.trace_key(k)][-1]
+ if not use_metatensor
+ else fwd_data[k].applied_operations[-1]
+ )
# make sure the OneOf index was stored
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
# make sure index exists and is in bounds
@@ -166,9 +175,11 @@ def test_inverse(self, transform, invertible):
if invertible:
for k in KEYS:
# check transform was removed
- self.assertTrue(
- len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)])
- )
+ if not use_metatensor:
+ self.assertTrue(
+ len(fwd_inv_data[TraceableTransform.trace_key(k)])
+ < len(fwd_data[TraceableTransform.trace_key(k)])
+ )
# check data is same as original (and different from forward)
self.assertEqual(fwd_inv_data[k], data[k])
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
@@ -186,15 +197,34 @@ def test_inverse_compose(self):
RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0),
]
),
+ OneOf(
+ [
+ RandScaleIntensityd(keys="img", factors=0.5, prob=1.0),
+ RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0),
+ ]
+ ),
]
)
transform.set_random_state(seed=0)
result = transform({"img": np.ones((1, 101, 102, 103))})
-
result = transform.inverse(result)
# invert to the original spatial shape
self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103))
+ def test_inverse_metatensor(self):
+ transform = Compose(
+ [
+ Resize(spatial_size=[100, 100, 100]),
+ OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]),
+ OneOf([RandScaleIntensity(factors=0.5, prob=1.0), RandShiftIntensity(offsets=0.5, prob=1.0)]),
+ ]
+ )
+ transform.set_random_state(seed=0)
+ result = transform(np.ones((1, 101, 102, 103)))
+ self.assertTupleEqual(result.shape, (1, 100, 100, 100))
+ result = transform.inverse(result)
+ self.assertTupleEqual(result.shape, (1, 101, 102, 103))
+
def test_one_of(self):
p = OneOf((A(), B(), C()), (1, 2, 1))
counts = [0] * 3
diff --git a/tests/test_ori_ras_lps.py b/tests/test_ori_ras_lps.py
index 4ed223bf5b5..d0a9b034e44 100644
--- a/tests/test_ori_ras_lps.py
+++ b/tests/test_ori_ras_lps.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import numpy as np
diff --git a/tests/test_orientation.py b/tests/test_orientation.py
index 7c4a1863c7f..979f6ae4850 100644
--- a/tests/test_orientation.py
+++ b/tests/test_orientation.py
@@ -167,7 +167,6 @@
for device in TEST_DEVICES:
TESTS_TORCH.append([{"axcodes": "LPS"}, torch.zeros((1, 3, 4, 5)), track_meta, *device])
-
ILL_CASES = [
# too short axcodes
[{"axcodes": "RA"}, torch.arange(12).reshape((2, 1, 2, 3)), torch.eye(4)]
diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py
index a46c117b754..9574afcceaf 100644
--- a/tests/test_patch_dataset.py
+++ b/tests/test_patch_dataset.py
@@ -59,7 +59,7 @@ def test_loading_array(self):
np.testing.assert_allclose(
item[0],
np.array(
- [[[1.338681, 2.338681, 3.338681], [5.338681, 6.338681, 7.338681], [9.338681, 10.338681, 11.338681]]]
+ [[[-0.593095, 0.406905, 1.406905], [3.406905, 4.406905, 5.406905], [7.406905, 8.406905, 9.406905]]]
),
rtol=1e-5,
)
@@ -69,13 +69,7 @@ def test_loading_array(self):
np.testing.assert_allclose(
item[0],
np.array(
- [
- [
- [4.957847, 5.957847, 6.957847],
- [8.957847, 9.957847, 10.957847],
- [12.957847, 13.957847, 14.957847],
- ]
- ]
+ [[[0.234308, 1.234308, 2.234308], [4.234308, 5.234308, 6.234308], [8.234308, 9.234308, 10.234308]]]
),
rtol=1e-5,
)
diff --git a/tests/test_patch_wsi_dataset.py b/tests/test_patch_wsi_dataset.py
index 20d7f22988f..0ba1a4a6494 100644
--- a/tests/test_patch_wsi_dataset.py
+++ b/tests/test_patch_wsi_dataset.py
@@ -17,22 +17,28 @@
from numpy.testing import assert_array_equal
from parameterized import parameterized
-from monai.apps.pathology.data import PatchWSIDataset
-from monai.utils import optional_import
+from monai.apps.pathology.data import PatchWSIDataset as PatchWSIDatasetDeprecated
+from monai.data import PatchWSIDataset
+from monai.data.wsi_reader import CuCIMWSIReader, OpenSlideWSIReader
+from monai.utils import deprecated, optional_import
+from monai.utils.enums import WSIPatchKeys
from tests.utils import download_url_or_skip_test, testing_data_config
-_cucim, has_cim = optional_import("cucim")
-has_cim = has_cim and hasattr(_cucim, "CuImage")
-_, has_osl = optional_import("openslide")
+cucim, has_cim = optional_import("cucim")
+has_cim = has_cim and hasattr(cucim, "CuImage")
+openslide, has_osl = optional_import("openslide")
+imwrite, has_tiff = optional_import("tifffile", name="imwrite")
+_, has_codec = optional_import("imagecodecs")
+has_tiff = has_tiff and has_codec
FILE_KEY = "wsi_img"
FILE_URL = testing_data_config("images", FILE_KEY, "url")
base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff"
FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension)
-TEST_CASE_0 = [
+TEST_CASE_DEP_0 = [
{
- "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
"region_size": (1, 1),
"grid_shape": (1, 1),
"patch_size": 1,
@@ -41,9 +47,9 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]
-TEST_CASE_0_L1 = [
+TEST_CASE_DEP_0_L1 = [
{
- "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
"region_size": (1, 1),
"grid_shape": (1, 1),
"patch_size": 1,
@@ -53,9 +59,9 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]
-TEST_CASE_0_L2 = [
+TEST_CASE_DEP_0_L2 = [
{
- "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
"region_size": (1, 1),
"grid_shape": (1, 1),
"patch_size": 1,
@@ -65,10 +71,9 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]
-
-TEST_CASE_1 = [
+TEST_CASE_DEP_1 = [
{
- "data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [10004, 20004], "label": [0, 0, 0, 1]}],
"region_size": (8, 8),
"grid_shape": (2, 2),
"patch_size": 1,
@@ -82,10 +87,9 @@
],
]
-
-TEST_CASE_1_L0 = [
+TEST_CASE_DEP_1_L0 = [
{
- "data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [10004, 20004], "label": [0, 0, 0, 1]}],
"region_size": (8, 8),
"grid_shape": (2, 2),
"patch_size": 1,
@@ -100,10 +104,9 @@
],
]
-
-TEST_CASE_1_L1 = [
+TEST_CASE_DEP_1_L1 = [
{
- "data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [10004, 20004], "label": [0, 0, 0, 1]}],
"region_size": (8, 8),
"grid_shape": (2, 2),
"patch_size": 1,
@@ -117,9 +120,9 @@
{"image": np.array([[[246]], [[242]], [[243]]], dtype=np.uint8), "label": np.array([[[1]]])},
],
]
-TEST_CASE_2 = [
+TEST_CASE_DEP_2 = [
{
- "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
"region_size": 1,
"grid_shape": 1,
"patch_size": 1,
@@ -128,9 +131,9 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]
-TEST_CASE_3 = [
+TEST_CASE_DEP_3 = [
{
- "data": [{"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [[[0, 1], [1, 0]]]}],
"region_size": 1,
"grid_shape": 1,
"patch_size": 1,
@@ -139,9 +142,9 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])}],
]
-TEST_CASE_OPENSLIDE_0 = [
+TEST_CASE_DEP_OPENSLIDE_0 = [
{
- "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
"region_size": (1, 1),
"grid_shape": (1, 1),
"patch_size": 1,
@@ -150,9 +153,9 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]
-TEST_CASE_OPENSLIDE_0_L0 = [
+TEST_CASE_DEP_OPENSLIDE_0_L0 = [
{
- "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
"region_size": (1, 1),
"grid_shape": (1, 1),
"patch_size": 1,
@@ -162,9 +165,9 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]
-TEST_CASE_OPENSLIDE_0_L1 = [
+TEST_CASE_DEP_OPENSLIDE_0_L1 = [
{
- "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
"region_size": (1, 1),
"grid_shape": (1, 1),
"patch_size": 1,
@@ -174,10 +177,9 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]
-
-TEST_CASE_OPENSLIDE_0_L2 = [
+TEST_CASE_DEP_OPENSLIDE_0_L2 = [
{
- "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
"region_size": (1, 1),
"grid_shape": (1, 1),
"patch_size": 1,
@@ -187,9 +189,9 @@
[{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1]]])}],
]
-TEST_CASE_OPENSLIDE_1 = [
+TEST_CASE_DEP_OPENSLIDE_1 = [
{
- "data": [{"image": FILE_PATH, "location": [10004, 20004], "label": [0, 0, 0, 1]}],
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [10004, 20004], "label": [0, 0, 0, 1]}],
"region_size": (8, 8),
"grid_shape": (2, 2),
"patch_size": 1,
@@ -203,53 +205,206 @@
],
]
+TEST_CASE_0 = [
+ {
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1], "patch_level": 0}],
+ "patch_size": (1, 1),
+ },
+ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
+]
-class TestPatchWSIDataset(unittest.TestCase):
- def setUp(self):
- hash_type = testing_data_config("images", FILE_KEY, "hash_type")
- hash_val = testing_data_config("images", FILE_KEY, "hash_val")
- download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)
+TEST_CASE_0_L1 = [
+ {
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
+ "patch_size": (1, 1),
+ "patch_level": 1,
+ },
+ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
+]
+
+TEST_CASE_0_L2 = [
+ {
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
+ "patch_size": (1, 1),
+ "patch_level": 1,
+ },
+ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
+]
+TEST_CASE_1 = [
+ {"data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], WSIPatchKeys.SIZE.value: 1, "label": [1]}]},
+ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
+]
+
+TEST_CASE_2 = [
+ {
+ "data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [1]}],
+ "patch_size": 1,
+ "patch_level": 0,
+ },
+ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
+]
+TEST_CASE_3 = [
+ {"data": [{"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [[[0, 1], [1, 0]]]}], "patch_size": 1},
+ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])},
+]
+
+TEST_CASE_4 = [
+ {
+ "data": [
+ {"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [[[0, 1], [1, 0]]]},
+ {"image": FILE_PATH, WSIPatchKeys.LOCATION.value: [0, 0], "label": [[[1, 0], [0, 0]]]},
+ ],
+ "patch_size": 1,
+ },
+ [
+ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])},
+ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1, 0], [0, 0]]])},
+ ],
+]
+
+TEST_CASE_5 = [
+ {
+ "data": [
+ {
+ "image": FILE_PATH,
+ WSIPatchKeys.LOCATION.value: [0, 0],
+ "label": [[[0, 1], [1, 0]]],
+ WSIPatchKeys.SIZE.value: 1,
+ WSIPatchKeys.LEVEL.value: 1,
+ },
+ {
+ "image": FILE_PATH,
+ WSIPatchKeys.LOCATION.value: [100, 100],
+ "label": [[[1, 0], [0, 0]]],
+ WSIPatchKeys.SIZE.value: 1,
+ WSIPatchKeys.LEVEL.value: 1,
+ },
+ ]
+ },
+ [
+ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])},
+ {"image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), "label": np.array([[[1, 0], [0, 0]]])},
+ ],
+]
+
+
+@skipUnless(has_cim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!")
+def setUpModule():
+ hash_type = testing_data_config("images", FILE_KEY, "hash_type")
+ hash_val = testing_data_config("images", FILE_KEY, "hash_val")
+ download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)
+
+
+@deprecated(since="0.8", msg_suffix="use tests for `monai.data.PatchWSIDataset` instead, `PatchWSIDatasetTests`.")
+class TestPatchWSIDatasetDeprecated(unittest.TestCase):
@parameterized.expand(
[
- TEST_CASE_0,
- TEST_CASE_0_L1,
- TEST_CASE_0_L2,
- TEST_CASE_1,
- TEST_CASE_1_L0,
- TEST_CASE_1_L1,
- TEST_CASE_2,
- TEST_CASE_3,
+ TEST_CASE_DEP_0,
+ TEST_CASE_DEP_0_L1,
+ TEST_CASE_DEP_0_L2,
+ TEST_CASE_DEP_1,
+ TEST_CASE_DEP_1_L0,
+ TEST_CASE_DEP_1_L1,
+ TEST_CASE_DEP_2,
+ TEST_CASE_DEP_3,
]
)
@skipUnless(has_cim, "Requires CuCIM")
def test_read_patches_cucim(self, input_parameters, expected):
- dataset = PatchWSIDataset(**input_parameters)
+ dataset = PatchWSIDatasetDeprecated(**input_parameters)
samples = dataset[0]
- for i in range(len(samples)):
- self.assertTupleEqual(samples[i]["label"].shape, expected[i]["label"].shape)
- self.assertTupleEqual(samples[i]["image"].shape, expected[i]["image"].shape)
- self.assertIsNone(assert_array_equal(samples[i]["label"], expected[i]["label"]))
- self.assertIsNone(assert_array_equal(samples[i]["image"], expected[i]["image"]))
+ for i, item in enumerate(samples):
+ self.assertTupleEqual(item["label"].shape, expected[i]["label"].shape)
+ self.assertTupleEqual(item["image"].shape, expected[i]["image"].shape)
+ self.assertIsNone(assert_array_equal(item["label"], expected[i]["label"]))
+ self.assertIsNone(assert_array_equal(item["image"], expected[i]["image"]))
@parameterized.expand(
[
- TEST_CASE_OPENSLIDE_0,
- TEST_CASE_OPENSLIDE_0_L0,
- TEST_CASE_OPENSLIDE_0_L1,
- TEST_CASE_OPENSLIDE_0_L2,
- TEST_CASE_OPENSLIDE_1,
+ TEST_CASE_DEP_OPENSLIDE_0,
+ TEST_CASE_DEP_OPENSLIDE_0_L0,
+ TEST_CASE_DEP_OPENSLIDE_0_L1,
+ TEST_CASE_DEP_OPENSLIDE_0_L2,
+ TEST_CASE_DEP_OPENSLIDE_1,
]
)
@skipUnless(has_osl, "Requires OpenSlide")
def test_read_patches_openslide(self, input_parameters, expected):
- dataset = PatchWSIDataset(**input_parameters)
+ dataset = PatchWSIDatasetDeprecated(**input_parameters)
samples = dataset[0]
- for i in range(len(samples)):
- self.assertTupleEqual(samples[i]["label"].shape, expected[i]["label"].shape)
- self.assertTupleEqual(samples[i]["image"].shape, expected[i]["image"].shape)
- self.assertIsNone(assert_array_equal(samples[i]["label"], expected[i]["label"]))
- self.assertIsNone(assert_array_equal(samples[i]["image"], expected[i]["image"]))
+ for i, item in enumerate(samples):
+ self.assertTupleEqual(item["label"].shape, expected[i]["label"].shape)
+ self.assertTupleEqual(item["image"].shape, expected[i]["image"].shape)
+ self.assertIsNone(assert_array_equal(item["label"], expected[i]["label"]))
+ self.assertIsNone(assert_array_equal(item["image"], expected[i]["image"]))
+
+
+class PatchWSIDatasetTests:
+ class Tests(unittest.TestCase):
+ backend = None
+
+ @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+ def test_read_patches_str(self, input_parameters, expected):
+ dataset = PatchWSIDataset(reader=self.backend, **input_parameters)
+ sample = dataset[0]
+ self.assertTupleEqual(sample["label"].shape, expected["label"].shape)
+ self.assertTupleEqual(sample["image"].shape, expected["image"].shape)
+ self.assertIsNone(assert_array_equal(sample["label"], expected["label"]))
+ self.assertIsNone(assert_array_equal(sample["image"], expected["image"]))
+
+ @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+ def test_read_patches_class(self, input_parameters, expected):
+ if self.backend == "openslide":
+ reader = OpenSlideWSIReader
+ elif self.backend == "cucim":
+ reader = CuCIMWSIReader
+ else:
+ raise ValueError("Unsupported backend: {self.backend}")
+ dataset = PatchWSIDataset(reader=reader, **input_parameters)
+ sample = dataset[0]
+ self.assertTupleEqual(sample["label"].shape, expected["label"].shape)
+ self.assertTupleEqual(sample["image"].shape, expected["image"].shape)
+ self.assertIsNone(assert_array_equal(sample["label"], expected["label"]))
+ self.assertIsNone(assert_array_equal(sample["image"], expected["image"]))
+
+ @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
+ def test_read_patches_object(self, input_parameters, expected):
+ if self.backend == "openslide":
+ reader = OpenSlideWSIReader(level=input_parameters.get("patch_level", 0))
+ elif self.backend == "cucim":
+ reader = CuCIMWSIReader(level=input_parameters.get("patch_level", 0))
+ else:
+ raise ValueError("Unsupported backend: {self.backend}")
+ dataset = PatchWSIDataset(reader=reader, **input_parameters)
+ sample = dataset[0]
+ self.assertTupleEqual(sample["label"].shape, expected["label"].shape)
+ self.assertTupleEqual(sample["image"].shape, expected["image"].shape)
+ self.assertIsNone(assert_array_equal(sample["label"], expected["label"]))
+ self.assertIsNone(assert_array_equal(sample["image"], expected["image"]))
+
+ @parameterized.expand([TEST_CASE_4, TEST_CASE_5])
+ def test_read_patches_str_multi(self, input_parameters, expected):
+ dataset = PatchWSIDataset(reader=self.backend, **input_parameters)
+ for i, item in enumerate(dataset):
+ self.assertTupleEqual(item["label"].shape, expected[i]["label"].shape)
+ self.assertTupleEqual(item["image"].shape, expected[i]["image"].shape)
+ self.assertIsNone(assert_array_equal(item["label"], expected[i]["label"]))
+ self.assertIsNone(assert_array_equal(item["image"], expected[i]["image"]))
+
+
+@skipUnless(has_cim, "Requires cucim")
+class TestPatchWSIDatasetCuCIM(PatchWSIDatasetTests.Tests):
+ @classmethod
+ def setUpClass(cls):
+ cls.backend = "cucim"
+
+
+@skipUnless(has_osl, "Requires openslide")
+class TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests):
+ @classmethod
+ def setUpClass(cls):
+ cls.backend = "openslide"
if __name__ == "__main__":
diff --git a/tests/test_patch_wsi_dataset_new.py b/tests/test_patch_wsi_dataset_new.py
deleted file mode 100644
index fee8a030689..00000000000
--- a/tests/test_patch_wsi_dataset_new.py
+++ /dev/null
@@ -1,181 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import unittest
-from unittest import skipUnless
-
-import numpy as np
-from numpy.testing import assert_array_equal
-from parameterized import parameterized
-
-from monai.data import PatchWSIDataset
-from monai.data.wsi_reader import CuCIMWSIReader, OpenSlideWSIReader
-from monai.utils import optional_import
-from tests.utils import download_url_or_skip_test, testing_data_config
-
-cucim, has_cucim = optional_import("cucim")
-has_cucim = has_cucim and hasattr(cucim, "CuImage")
-openslide, has_osl = optional_import("openslide")
-imwrite, has_tiff = optional_import("tifffile", name="imwrite")
-_, has_codec = optional_import("imagecodecs")
-has_tiff = has_tiff and has_codec
-
-FILE_KEY = "wsi_img"
-FILE_URL = testing_data_config("images", FILE_KEY, "url")
-base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff"
-FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension)
-
-TEST_CASE_0 = [
- {"data": [{"image": FILE_PATH, "patch_location": [0, 0], "label": [1], "patch_level": 0}], "patch_size": (1, 1)},
- {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
-]
-
-TEST_CASE_0_L1 = [
- {"data": [{"image": FILE_PATH, "patch_location": [0, 0], "label": [1]}], "patch_size": (1, 1), "patch_level": 1},
- {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
-]
-
-TEST_CASE_0_L2 = [
- {"data": [{"image": FILE_PATH, "patch_location": [0, 0], "label": [1]}], "patch_size": (1, 1), "patch_level": 1},
- {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
-]
-TEST_CASE_1 = [
- {"data": [{"image": FILE_PATH, "patch_location": [0, 0], "patch_size": 1, "label": [1]}]},
- {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
-]
-
-TEST_CASE_2 = [
- {"data": [{"image": FILE_PATH, "patch_location": [0, 0], "label": [1]}], "patch_size": 1, "patch_level": 0},
- {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([1])},
-]
-
-TEST_CASE_3 = [
- {"data": [{"image": FILE_PATH, "patch_location": [0, 0], "label": [[[0, 1], [1, 0]]]}], "patch_size": 1},
- {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])},
-]
-
-TEST_CASE_4 = [
- {
- "data": [
- {"image": FILE_PATH, "patch_location": [0, 0], "label": [[[0, 1], [1, 0]]]},
- {"image": FILE_PATH, "patch_location": [0, 0], "label": [[[1, 0], [0, 0]]]},
- ],
- "patch_size": 1,
- },
- [
- {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])},
- {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[1, 0], [0, 0]]])},
- ],
-]
-
-TEST_CASE_5 = [
- {
- "data": [
- {
- "image": FILE_PATH,
- "patch_location": [0, 0],
- "label": [[[0, 1], [1, 0]]],
- "patch_size": 1,
- "patch_level": 1,
- },
- {
- "image": FILE_PATH,
- "patch_location": [100, 100],
- "label": [[[1, 0], [0, 0]]],
- "patch_size": 1,
- "patch_level": 1,
- },
- ]
- },
- [
- {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0, 1], [1, 0]]])},
- {"image": np.array([[[243]], [[243]], [[243]]], dtype=np.uint8), "label": np.array([[[1, 0], [0, 0]]])},
- ],
-]
-
-
-@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!")
-def setUpModule():
- hash_type = testing_data_config("images", FILE_KEY, "hash_type")
- hash_val = testing_data_config("images", FILE_KEY, "hash_val")
- download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)
-
-
-class PatchWSIDatasetTests:
- class Tests(unittest.TestCase):
- backend = None
-
- @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
- def test_read_patches_str(self, input_parameters, expected):
- dataset = PatchWSIDataset(reader=self.backend, **input_parameters)
- sample = dataset[0]
- self.assertTupleEqual(sample["label"].shape, expected["label"].shape)
- self.assertTupleEqual(sample["image"].shape, expected["image"].shape)
- self.assertIsNone(assert_array_equal(sample["label"], expected["label"]))
- self.assertIsNone(assert_array_equal(sample["image"], expected["image"]))
-
- @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
- def test_read_patches_class(self, input_parameters, expected):
- if self.backend == "openslide":
- reader = OpenSlideWSIReader
- elif self.backend == "cucim":
- reader = CuCIMWSIReader
- else:
- raise ValueError("Unsupported backend: {self.backend}")
- dataset = PatchWSIDataset(reader=reader, **input_parameters)
- sample = dataset[0]
- self.assertTupleEqual(sample["label"].shape, expected["label"].shape)
- self.assertTupleEqual(sample["image"].shape, expected["image"].shape)
- self.assertIsNone(assert_array_equal(sample["label"], expected["label"]))
- self.assertIsNone(assert_array_equal(sample["image"], expected["image"]))
-
- @parameterized.expand([TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
- def test_read_patches_object(self, input_parameters, expected):
- if self.backend == "openslide":
- reader = OpenSlideWSIReader(level=input_parameters.get("patch_level", 0))
- elif self.backend == "cucim":
- reader = CuCIMWSIReader(level=input_parameters.get("patch_level", 0))
- else:
- raise ValueError("Unsupported backend: {self.backend}")
- dataset = PatchWSIDataset(reader=reader, **input_parameters)
- sample = dataset[0]
- self.assertTupleEqual(sample["label"].shape, expected["label"].shape)
- self.assertTupleEqual(sample["image"].shape, expected["image"].shape)
- self.assertIsNone(assert_array_equal(sample["label"], expected["label"]))
- self.assertIsNone(assert_array_equal(sample["image"], expected["image"]))
-
- @parameterized.expand([TEST_CASE_4, TEST_CASE_5])
- def test_read_patches_str_multi(self, input_parameters, expected):
- dataset = PatchWSIDataset(reader=self.backend, **input_parameters)
- for i in range(len(dataset)):
- self.assertTupleEqual(dataset[i]["label"].shape, expected[i]["label"].shape)
- self.assertTupleEqual(dataset[i]["image"].shape, expected[i]["image"].shape)
- self.assertIsNone(assert_array_equal(dataset[i]["label"], expected[i]["label"]))
- self.assertIsNone(assert_array_equal(dataset[i]["image"], expected[i]["image"]))
-
-
-@skipUnless(has_cucim, "Requires cucim")
-class TestPatchWSIDatasetCuCIM(PatchWSIDatasetTests.Tests):
- @classmethod
- def setUpClass(cls):
- cls.backend = "cucim"
-
-
-@skipUnless(has_osl, "Requires openslide")
-class TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests):
- @classmethod
- def setUpClass(cls):
- cls.backend = "openslide"
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_pathology_he_stain.py b/tests/test_pathology_he_stain.py
index 7b884315fc0..ac4e94144a2 100644
--- a/tests/test_pathology_he_stain.py
+++ b/tests/test_pathology_he_stain.py
@@ -49,7 +49,6 @@
np.array([[0.70710677, 0.18696113], [0.0, 0.0], [0.70710677, 0.98236734]]),
]
-
# input pixels all transparent and below the beta absorbance threshold
NORMALIZE_STAINS_TEST_CASE_1 = [np.full((3, 2, 3), 240)]
diff --git a/tests/test_pil_reader.py b/tests/test_pil_reader.py
index 0f7792a56c0..4f0b891b723 100644
--- a/tests/test_pil_reader.py
+++ b/tests/test_pil_reader.py
@@ -64,6 +64,7 @@ def test_converter(self, data_shape, filenames, expected_shape, meta_shape):
Image.fromarray(test_image.astype("uint8")).save(filenames[i])
reader = PILReader(converter=lambda image: image.convert("LA"))
result = reader.get_data(reader.read(filenames, mode="r"))
+ self.assertEqual(result[1]["format"], "none") # project-monai/monai issue#5251
# load image by PIL and compare the result
test_image = np.asarray(Image.open(filenames[0]).convert("LA"))
diff --git a/tests/test_prepare_batch_default.py b/tests/test_prepare_batch_default.py
index 96051b5e826..e3836ed86fa 100644
--- a/tests/test_prepare_batch_default.py
+++ b/tests/test_prepare_batch_default.py
@@ -23,7 +23,7 @@ def forward(self, x: torch.Tensor):
class TestPrepareBatchDefault(unittest.TestCase):
- def test_content(self):
+ def test_dict_content(self):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataloader = [
{
@@ -50,6 +50,46 @@ def test_content(self):
assert_allclose(output["image"], torch.tensor([1, 2], device=device))
assert_allclose(output["label"], torch.tensor([3, 4], device=device))
+ def test_tensor_content(self):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dataloader = [torch.tensor([1, 2])]
+
+ # set up engine
+ evaluator = SupervisedEvaluator(
+ device=device,
+ val_data_loader=dataloader,
+ epoch_length=1,
+ network=torch.nn.Identity(),
+ non_blocking=False,
+ prepare_batch=PrepareBatchDefault(),
+ decollate=False,
+ mode="eval",
+ )
+ evaluator.run()
+ output = evaluator.state.output
+ assert_allclose(output["image"], torch.tensor([1, 2], device=device))
+ self.assertTrue(output["label"] is None)
+
+ def test_pair_content(self):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dataloader = [(torch.tensor([1, 2]), torch.tensor([3, 4]))]
+
+ # set up engine
+ evaluator = SupervisedEvaluator(
+ device=device,
+ val_data_loader=dataloader,
+ epoch_length=1,
+ network=torch.nn.Identity(),
+ non_blocking=False,
+ prepare_batch=PrepareBatchDefault(),
+ decollate=False,
+ mode="eval",
+ )
+ evaluator.run()
+ output = evaluator.state.output
+ assert_allclose(output["image"], torch.tensor([1, 2], device=device))
+ assert_allclose(output["label"], torch.tensor([3, 4], device=device))
+
def test_empty_data(self):
dataloader = []
evaluator = SupervisedEvaluator(
diff --git a/tests/test_prepare_batch_default_dist.py b/tests/test_prepare_batch_default_dist.py
index 95d01d2a160..3c7532e9162 100644
--- a/tests/test_prepare_batch_default_dist.py
+++ b/tests/test_prepare_batch_default_dist.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
import torch
diff --git a/tests/test_prepare_batch_hovernet.py b/tests/test_prepare_batch_hovernet.py
new file mode 100644
index 00000000000..9aed8e94c70
--- /dev/null
+++ b/tests/test_prepare_batch_hovernet.py
@@ -0,0 +1,66 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.apps.pathology.engines import PrepareBatchHoVerNet
+from monai.engines import SupervisedEvaluator
+from monai.utils.enums import HoVerNetBranch
+from tests.utils import assert_allclose
+
+TEST_CASE_0 = [
+ {"extra_keys": ["extra_label1", "extra_label2"]},
+ {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16},
+]
+
+
+class TestNet(torch.nn.Module):
+ def forward(self, x: torch.Tensor):
+ return {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16}
+
+
+class TestPrepareBatchHoVerNet(unittest.TestCase):
+ @parameterized.expand([TEST_CASE_0])
+ def test_content(self, input_args, expected_value):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dataloader = [
+ {
+ "image": torch.tensor([1, 2]),
+ "label": torch.tensor([1, 2]),
+ "extra_label1": torch.tensor([3, 4]),
+ "extra_label2": 16,
+ }
+ ]
+ # set up engine
+ evaluator = SupervisedEvaluator(
+ device=device,
+ val_data_loader=dataloader,
+ epoch_length=1,
+ network=TestNet(),
+ non_blocking=True,
+ prepare_batch=PrepareBatchHoVerNet(**input_args),
+ decollate=False,
+ )
+ evaluator.run()
+ output = evaluator.state.output
+ assert_allclose(output["image"], torch.tensor([1, 2], device=device))
+ for k, v in output["pred"].items():
+ if isinstance(v, torch.Tensor):
+ assert_allclose(v, expected_value[k].to(device))
+ else:
+ self.assertEqual(v, expected_value[k])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_pytorch_version_after.py b/tests/test_pytorch_version_after.py
index 68abb9571fb..be43e49f827 100644
--- a/tests/test_pytorch_version_after.py
+++ b/tests/test_pytorch_version_after.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import unittest
from parameterized import parameterized
diff --git a/tests/test_rand_cucim_dict_transform.py b/tests/test_rand_cucim_dict_transform.py
index a109bee845c..e72101e4706 100644
--- a/tests/test_rand_cucim_dict_transform.py
+++ b/tests/test_rand_cucim_dict_transform.py
@@ -41,7 +41,6 @@
np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32),
]
-
TEST_CASE_RAND_ROTATE_2 = [
{"name": "rand_image_rotate_90", "prob": 0.0, "max_k": 1, "spatial_axis": (-2, -1)},
np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),
diff --git a/tests/test_rand_cucim_transform.py b/tests/test_rand_cucim_transform.py
index 30164e4170d..0f37b3f6cba 100644
--- a/tests/test_rand_cucim_transform.py
+++ b/tests/test_rand_cucim_transform.py
@@ -41,7 +41,6 @@
np.array([[[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]], [[1.0, 3.0], [0.0, 2.0]]], dtype=np.float32),
]
-
TEST_CASE_RAND_ROTATE_2 = [
{"name": "rand_image_rotate_90", "prob": 0.0, "max_k": 1, "spatial_axis": (-2, -1)},
np.array([[[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]], [[0.0, 1.0], [2.0, 3.0]]], dtype=np.float32),
diff --git a/tests/test_rand_elasticd_2d.py b/tests/test_rand_elasticd_2d.py
index 759ba2c4da4..d6f7a0cbba5 100644
--- a/tests/test_rand_elasticd_2d.py
+++ b/tests/test_rand_elasticd_2d.py
@@ -161,6 +161,8 @@ class TestRand2DElasticd(unittest.TestCase):
@parameterized.expand(TESTS)
def test_rand_2d_elasticd(self, input_param, input_data, expected_val):
g = Rand2DElasticd(**input_param)
+ if input_param.get("device", None) is None and isinstance(input_data["img"], torch.Tensor):
+ input_data["img"].to("cuda:0" if torch.cuda.is_available() else "cpu")
g.set_random_state(123)
res = g(input_data)
for key in res:
diff --git a/tests/test_rand_elasticd_3d.py b/tests/test_rand_elasticd_3d.py
index eaba06c9530..9db474861ed 100644
--- a/tests/test_rand_elasticd_3d.py
+++ b/tests/test_rand_elasticd_3d.py
@@ -141,6 +141,8 @@ class TestRand3DElasticd(unittest.TestCase):
def test_rand_3d_elasticd(self, input_param, input_data, expected_val):
g = Rand3DElasticd(**input_param)
g.set_random_state(123)
+ if input_param.get("device", None) is None and isinstance(input_data["img"], torch.Tensor):
+ input_data["img"].to("cuda:0" if torch.cuda.is_available() else "cpu")
res = g(input_data)
for key in res:
result = res[key]
diff --git a/tests/test_rand_gibbs_noised.py b/tests/test_rand_gibbs_noised.py
index 8b15fcc2678..23a7dd5fdb3 100644
--- a/tests/test_rand_gibbs_noised.py
+++ b/tests/test_rand_gibbs_noised.py
@@ -13,7 +13,6 @@
from copy import deepcopy
import numpy as np
-import torch
from parameterized import parameterized
from monai.data.synthetic import create_test_image_2d, create_test_image_3d
@@ -53,7 +52,7 @@ def test_0_prob(self, im_shape, input_type):
t = RandGibbsNoised(KEYS, 0.0, alpha)
out = t(data)
for k in KEYS:
- torch.testing.assert_allclose(data[k], out[k], rtol=1e-7, atol=0)
+ assert_allclose(data[k], out[k], rtol=1e-7, atol=0, type_test=False)
@parameterized.expand(TEST_CASES)
def test_same_result(self, im_shape, input_type):
@@ -93,7 +92,7 @@ def test_dict_matches(self, im_shape, input_type):
alpha = [0.5, 1.0]
t = RandGibbsNoised(KEYS, 1.0, alpha)
out = t(deepcopy(data))
- torch.testing.assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0)
+ assert_allclose(out[KEYS[0]], out[KEYS[1]], rtol=1e-7, atol=0, type_test=False)
@parameterized.expand(TEST_CASES)
def test_alpha(self, im_shape, input_type):
diff --git a/tests/test_rand_grid_patch.py b/tests/test_rand_grid_patch.py
index 3957dc1ce8f..417915fbabc 100644
--- a/tests/test_rand_grid_patch.py
+++ b/tests/test_rand_grid_patch.py
@@ -14,9 +14,10 @@
import numpy as np
from parameterized import parameterized
+from monai.data import MetaTensor, set_track_meta
from monai.transforms.spatial.array import RandGridPatch
from monai.utils import set_determinism
-from tests.utils import TEST_NDARRAYS, assert_allclose
+from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose
set_determinism(1234)
@@ -57,6 +58,25 @@
]
TEST_CASE_10 = [{"patch_size": (2, 2), "min_offset": 0, "max_offset": 0, "threshold": 50.0}, A, [A11]]
+TEST_CASE_MEAT_0 = [
+ {"patch_size": (2, 2)},
+ A,
+ [A11, A12, A21, A22],
+ [{"location": [0, 0]}, {"location": [0, 2]}, {"location": [2, 0]}, {"location": [2, 2]}],
+]
+
+TEST_CASE_MEAT_1 = [
+ {"patch_size": (2, 2)},
+ MetaTensor(x=A, meta={"path": "path/to/file"}),
+ [A11, A12, A21, A22],
+ [
+ {"location": [0, 0], "path": "path/to/file"},
+ {"location": [0, 2], "path": "path/to/file"},
+ {"location": [2, 0], "path": "path/to/file"},
+ {"location": [2, 2], "path": "path/to/file"},
+ ],
+]
+
TEST_SINGLE = []
for p in TEST_NDARRAYS:
TEST_SINGLE.append([p, *TEST_CASE_0])
@@ -78,10 +98,28 @@ def test_rand_grid_patch(self, in_type, input_parameters, image, expected):
input_image = in_type(image)
splitter = RandGridPatch(**input_parameters)
splitter.set_random_state(1234)
- output = list(splitter(input_image))
+ output = splitter(input_image)
self.assertEqual(len(output), len(expected))
for output_patch, expected_patch in zip(output, expected):
- assert_allclose(output_patch[0], expected_patch, type_test=False)
+ assert_allclose(output_patch, expected_patch, type_test=False)
+
+ @parameterized.expand([TEST_CASE_MEAT_0, TEST_CASE_MEAT_1])
+ @SkipIfBeforePyTorchVersion((1, 9, 1))
+ def test_rand_grid_patch_meta(self, input_parameters, image, expected, expected_meta):
+ set_track_meta(True)
+ splitter = RandGridPatch(**input_parameters)
+ splitter.set_random_state(1234)
+ output = splitter(image)
+ self.assertEqual(len(output), len(expected))
+ if "path" in expected_meta[0]:
+ self.assertTrue(output.meta["path"] == expected_meta[0]["path"])
+ for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta):
+ assert_allclose(output_patch, expected_patch, type_test=False)
+ self.assertTrue(isinstance(output_patch, MetaTensor))
+ self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"])
+ self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:]))
+ if "path" in expected_meta[0]:
+ self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"])
if __name__ == "__main__":
diff --git a/tests/test_rand_grid_patchd.py b/tests/test_rand_grid_patchd.py
index 656fbd9e366..4f3ec3bb6ac 100644
--- a/tests/test_rand_grid_patchd.py
+++ b/tests/test_rand_grid_patchd.py
@@ -83,10 +83,10 @@ def test_rand_grid_patchd(self, in_type, input_parameters, image_dict, expected)
input_dict[k] = in_type(v)
splitter = RandGridPatchd(keys=image_key, **input_parameters)
splitter.set_random_state(1234)
- output = list(splitter(input_dict))
- self.assertEqual(len(output), len(expected))
- for output_patch, expected_patch in zip(output, expected):
- assert_allclose(output_patch[image_key], expected_patch, type_test=False)
+ output = splitter(input_dict)
+ self.assertEqual(len(output[image_key]), len(expected))
+ for output_patch, expected_patch in zip(output[image_key], expected):
+ assert_allclose(output_patch, expected_patch, type_test=False)
if __name__ == "__main__":
diff --git a/tests/test_rand_k_space_spike_noised.py b/tests/test_rand_k_space_spike_noised.py
index 156c95822fc..7f493ef2760 100644
--- a/tests/test_rand_k_space_spike_noised.py
+++ b/tests/test_rand_k_space_spike_noised.py
@@ -40,7 +40,7 @@ def get_data(im_shape, im_type):
create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d
ims = create_test_image(*im_shape, rad_max=20, noise_max=0.0, num_seg_classes=5)
ims = [im_type(im[None]) for im in ims]
- return {k: v for k, v in zip(KEYS, ims)}
+ return dict(zip(KEYS, ims))
@parameterized.expand(TESTS)
def test_same_result(self, im_shape, im_type):
diff --git a/tests/test_rand_lambda.py b/tests/test_rand_lambda.py
index c356406f613..cb5c57e9e46 100644
--- a/tests/test_rand_lambda.py
+++ b/tests/test_rand_lambda.py
@@ -71,12 +71,10 @@ def test_rand_lambdad_identity(self, t):
ret = tr(img)
self.check(tr, img, img_t, ret, expected)
- # prob = 0
tr = RandLambda(func=test_func, prob=0.0)
ret = tr(img)
self.check(tr, img, img_t, ret, expected=img)
- # prob = 0.5
trans = RandLambda(func=test_func, prob=0.5)
trans.set_random_state(seed=123)
ret = trans(img)
diff --git a/tests/test_rand_lambdad.py b/tests/test_rand_lambdad.py
index b181db50356..8bd7bbbfc85 100644
--- a/tests/test_rand_lambdad.py
+++ b/tests/test_rand_lambdad.py
@@ -61,12 +61,10 @@ def test_rand_lambdad_identity(self, t):
ret = tr(deepcopy(data))
self.check(tr, data, ret, expected)
- # prob = 0
tr = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.0)
ret = tr(deepcopy(data))
self.check(tr, data, ret, expected=data)
- # prob = 0.5
trans = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.5)
trans.set_random_state(seed=123)
ret = trans(deepcopy(data))
diff --git a/tests/test_rand_rician_noise.py b/tests/test_rand_rician_noise.py
index 896ae8b2e0a..9ee1a6ce822 100644
--- a/tests/test_rand_rician_noise.py
+++ b/tests/test_rand_rician_noise.py
@@ -32,6 +32,8 @@ def test_correct_results(self, _, in_type, mean, std):
rician_fn.set_random_state(seed)
im = in_type(self.imt)
noised = rician_fn(im)
+ if isinstance(im, torch.Tensor):
+ self.assertEqual(im.dtype, noised.dtype)
np.random.seed(seed)
np.random.random()
_std = np.random.uniform(0, std)
diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py
index bdee0474d0a..5d3a76a86d4 100644
--- a/tests/test_rand_rotate.py
+++ b/tests/test_rand_rotate.py
@@ -19,7 +19,13 @@
from monai.data import MetaTensor, set_track_meta
from monai.transforms import RandRotate
-from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion
+from tests.utils import (
+ TEST_NDARRAYS_ALL,
+ NumpyImageTestCase2D,
+ NumpyImageTestCase3D,
+ assert_allclose,
+ test_local_inversion,
+)
TEST_CASES_2D: List[Tuple] = []
for p in TEST_NDARRAYS_ALL:
@@ -111,7 +117,7 @@ def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode,
rotate_fn.set_random_state(243)
im = im_type(self.imt[0])
rotated = rotate_fn(im)
- torch.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0)
+ assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0)
test_local_inversion(rotate_fn, rotated, im)
set_track_meta(False)
@@ -121,5 +127,22 @@ def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode,
set_track_meta(True)
+class TestRandRotateDtype(NumpyImageTestCase2D):
+ @parameterized.expand(TEST_CASES_2D)
+ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode, align_corners):
+ rotate_fn = RandRotate(
+ range_x=1.0,
+ prob=0.5,
+ keep_size=keep_size,
+ mode=mode,
+ padding_mode=padding_mode,
+ align_corners=align_corners,
+ dtype=np.float64,
+ )
+ im = im_type(self.imt[0])
+ rotated = rotate_fn(im)
+ self.assertEqual(rotated.dtype, torch.float32)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py
index 906977f3fa8..e19f8a513f2 100644
--- a/tests/test_rand_rotated.py
+++ b/tests/test_rand_rotated.py
@@ -28,7 +28,6 @@
TEST_CASES_2D.append((p, np.pi, False, "nearest", "zeros", True))
TEST_CASES_2D.append((p, (-np.pi / 4, 0), False, "nearest", "zeros", True))
-
TEST_CASES_3D: List[Tuple] = []
for p in TEST_NDARRAYS_ALL:
TEST_CASES_3D.append(
@@ -145,7 +144,7 @@ class TestRandRotated3D(NumpyImageTestCase3D):
@parameterized.expand(TEST_CASES_3D)
def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, align_corners, expected):
rotate_fn = RandRotated(
- "img",
+ ("img", "seg"),
range_x=x,
range_y=y,
range_z=z,
@@ -160,6 +159,10 @@ def test_correct_shapes(self, im_type, x, y, z, keep_size, mode, padding_mode, a
rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])})
np.testing.assert_allclose(rotated["img"].shape, expected)
+ rotate_fn.prob = 0.0
+ rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])})
+ self.assertEqual(rotated["seg"].dtype, torch.float32)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_rand_std_shift_intensity.py b/tests/test_rand_std_shift_intensity.py
index b26f5ef0962..a2345dca1d9 100644
--- a/tests/test_rand_std_shift_intensity.py
+++ b/tests/test_rand_std_shift_intensity.py
@@ -12,6 +12,7 @@
import unittest
import numpy as np
+import torch
from parameterized import parameterized
from monai.transforms import RandStdShiftIntensity
@@ -29,7 +30,10 @@ def test_value(self, p):
expected = p(self.imt + offset)
shifter = RandStdShiftIntensity(factors=1.0, prob=1.0)
shifter.set_random_state(seed=0)
- result = shifter(p(self.imt))
+ _imt = p(self.imt)
+ result = shifter(_imt)
+ if isinstance(_imt, torch.Tensor):
+ self.assertEqual(result.dtype, _imt.dtype)
assert_allclose(result, expected, atol=0, rtol=1e-5, type_test="tensor")
diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py
index 53913ce9874..2e1fcad4b2a 100644
--- a/tests/test_rand_weighted_crop.py
+++ b/tests/test_rand_weighted_crop.py
@@ -28,7 +28,6 @@ def get_data(ndim):
IMT_2D, SEG1_2D, SEGN_2D = get_data(ndim=2)
IMT_3D, SEG1_3D, SEGN_3D = get_data(ndim=3)
-
TESTS = []
for p in TEST_NDARRAYS_ALL:
for q in TEST_NDARRAYS_ALL:
diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py
index fc8280490fc..b34d3b04194 100644
--- a/tests/test_rand_zoom.py
+++ b/tests/test_rand_zoom.py
@@ -12,6 +12,7 @@
import unittest
import numpy as np
+import torch
from parameterized import parameterized
from scipy.ndimage import zoom as zoom_scipy
@@ -51,6 +52,8 @@ def test_keep_size(self):
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
zoomed = random_zoom(im)
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
+ random_zoom.prob = 0.0
+ self.assertEqual(random_zoom(im).dtype, torch.float32)
@parameterized.expand(
[("no_min_zoom", None, 1.1, "bilinear", TypeError), ("invalid_mode", 0.9, 1.1, "s", ValueError)]
diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py
index b2ae40530a2..3a067e8a90c 100644
--- a/tests/test_rand_zoomd.py
+++ b/tests/test_rand_zoomd.py
@@ -12,6 +12,7 @@
import unittest
import numpy as np
+import torch
from parameterized import parameterized
from scipy.ndimage import zoom as zoom_scipy
@@ -58,6 +59,8 @@ def test_keep_size(self):
zoomed = random_zoom({key: im})
test_local_inversion(random_zoom, zoomed, {key: im}, key)
np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:])
+ random_zoom.prob = 0.0
+ self.assertEqual(random_zoom({key: p(self.imt[0])})[key].dtype, torch.float32)
@parameterized.expand(
[("no_min_zoom", None, 1.1, "bilinear", TypeError), ("invalid_order", 0.9, 1.1, "s", ValueError)]
diff --git a/tests/test_randomizable_transform_type.py b/tests/test_randomizable_transform_type.py
new file mode 100644
index 00000000000..9f77d2cd5a3
--- /dev/null
+++ b/tests/test_randomizable_transform_type.py
@@ -0,0 +1,33 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from monai.transforms.transform import RandomizableTrait, RandomizableTransform
+
+
+class InheritsInterface(RandomizableTrait):
+ pass
+
+
+class InheritsImplementation(RandomizableTransform):
+ def __call__(self, data):
+ return data
+
+
+class TestRandomizableTransformType(unittest.TestCase):
+ def test_is_randomizable_transform_type(self):
+ inst = InheritsInterface()
+ self.assertIsInstance(inst, RandomizableTrait)
+
+ def test_set_random_state_randomizable_transform(self):
+ inst = InheritsImplementation()
+ inst.set_random_state(0)
diff --git a/tests/test_randtorchvisiond.py b/tests/test_randtorchvisiond.py
index 4fc1a1e6306..c0588581445 100644
--- a/tests/test_randtorchvisiond.py
+++ b/tests/test_randtorchvisiond.py
@@ -16,6 +16,7 @@
from monai.transforms import Randomizable, RandTorchVisiond
from monai.utils import set_determinism
+from tests.utils import assert_allclose
TEST_CASE_1 = [
{"keys": "img", "name": "ColorJitter"},
@@ -55,7 +56,7 @@ def test_value(self, input_param, input_data, expected_value):
transform = RandTorchVisiond(**input_param)
result = transform(input_data)
self.assertTrue(isinstance(transform, Randomizable))
- torch.testing.assert_allclose(result["img"], expected_value)
+ assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4)
if __name__ == "__main__":
diff --git a/tests/test_recon_net_utils.py b/tests/test_recon_net_utils.py
index 18d89296d82..6621bf735ef 100644
--- a/tests/test_recon_net_utils.py
+++ b/tests/test_recon_net_utils.py
@@ -16,11 +16,16 @@
from monai.apps.reconstruction.networks.nets.utils import (
complex_normalize,
+ divisible_pad_t,
+ inverse_divisible_pad_t,
reshape_batch_channel_to_channel_dim,
reshape_channel_complex_to_last_dim,
reshape_channel_to_batch_dim,
reshape_complex_to_channel_dim,
+ sensitivity_map_expand,
+ sensitivity_map_reduce,
)
+from tests.utils import assert_allclose
# no need for checking devices, these functions don't change device format
# reshape test case
@@ -31,6 +36,15 @@
im_2d, im_3d = torch.randint(0, 3, [3, 4, 50, 70]).float(), torch.randint(0, 3, [3, 4, 50, 70, 80]).float()
TEST_NORMALIZE = [(im_2d,), (im_3d,)]
+# pad test case
+im_2d, im_3d = torch.ones([3, 4, 50, 70]), torch.ones([3, 4, 50, 70, 80])
+TEST_PAD = [(im_2d,), (im_3d,)]
+
+# test case for sensitivity map expansion/reduction
+ksp_2d, ksp_3d = torch.ones([3, 4, 50, 70, 2]), torch.ones([3, 4, 50, 70, 80, 2])
+sens_2d, sens_3d = torch.ones([3, 4, 50, 70, 2]), torch.ones([3, 4, 50, 70, 80, 2])
+TEST_SENS = [(ksp_2d, sens_2d), (ksp_3d, sens_3d)]
+
class TestReconNetUtils(unittest.TestCase):
@parameterized.expand(TEST_RESHAPE)
@@ -49,6 +63,18 @@ def test_complex_normalize(self, test_data):
result = result * std + mean
self.assertTrue((((result - test_data) ** 2).mean() ** 0.5).item() < 1e-5)
+ @parameterized.expand(TEST_PAD)
+ def test_pad(self, test_data):
+ result, padding_sizes = divisible_pad_t(test_data, k=16)
+ result = inverse_divisible_pad_t(result, padding_sizes)
+ assert_allclose(result, test_data)
+
+ @parameterized.expand(TEST_SENS)
+ def test_sens_expand_reduce(self, test_data, sens):
+ result = sensitivity_map_reduce(test_data, sens)
+ result = sensitivity_map_expand(result, sens)
+ self.assertEqual(result.shape, test_data.shape)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_reference_based_normalize_intensity.py b/tests/test_reference_based_normalize_intensity.py
index 0f5fa7d6279..01811e59072 100644
--- a/tests/test_reference_based_normalize_intensity.py
+++ b/tests/test_reference_based_normalize_intensity.py
@@ -24,7 +24,6 @@
# which focuses on (1) automatic target normalization and (2) mean-std
# return values
-
TESTS = []
for p in TEST_NDARRAYS_NO_META_TENSOR:
TESTS.append(
diff --git a/tests/test_reference_based_spatial_cropd.py b/tests/test_reference_based_spatial_cropd.py
index d1f6230da4c..ab5573044d3 100644
--- a/tests/test_reference_based_spatial_cropd.py
+++ b/tests/test_reference_based_spatial_cropd.py
@@ -22,7 +22,6 @@
# here, we test TargetBasedSpatialCropd's functionality
# which focuses on automatic input crop based on target image's shape.
-
TESTS = []
for p in TEST_NDARRAYS:
# 2D
diff --git a/tests/test_regunet.py b/tests/test_regunet.py
index e37ca49538f..04f971d2eb6 100644
--- a/tests/test_regunet.py
+++ b/tests/test_regunet.py
@@ -20,7 +20,6 @@
device = "cuda" if torch.cuda.is_available() else "cpu"
-
TEST_CASE_REGUNET_2D = [
[
{
diff --git a/tests/test_remove_small_objects.py b/tests/test_remove_small_objects.py
new file mode 100644
index 00000000000..7130d607396
--- /dev/null
+++ b/tests/test_remove_small_objects.py
@@ -0,0 +1,74 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from typing import List, Tuple
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.data.meta_tensor import MetaTensor
+from monai.transforms.post.array import RemoveSmallObjects
+from monai.transforms.post.dictionary import RemoveSmallObjectsd
+from monai.utils import optional_import
+from tests.utils import TEST_NDARRAYS, SkipIfNoModule, assert_allclose
+
+morphology, has_morphology = optional_import("skimage.morphology")
+
+TEST_ZEROS = np.zeros((1, 9, 8, 7))
+TEST_ONES = np.ones((3, 7, 8, 9))
+
+TEST_INPUT1 = np.array([[[0, 0, 2, 1, 0], [1, 1, 1, 2, 0], [1, 1, 1, 0, 1]]])
+
+TEST_OUTPUT1 = np.array([[[0, 0, 2, 1, 0], [1, 1, 1, 2, 0], [1, 1, 1, 0, 0]]])
+
+TESTS: List[Tuple] = []
+for dtype in (int, float):
+ for p in TEST_NDARRAYS:
+ TESTS.append((dtype, p, TEST_ZEROS, None))
+ TESTS.append((dtype, p, TEST_ONES, None))
+ TESTS.append((dtype, p, TEST_INPUT1, None, {"min_size": 6}))
+ TESTS.append((dtype, p, TEST_INPUT1, None, {"min_size": 7, "connectivity": 2}))
+ # for non-independent channels, the twos should stay
+ TESTS.append((dtype, p, TEST_INPUT1, TEST_OUTPUT1, {"min_size": 2, "independent_channels": False}))
+
+
+@SkipIfNoModule("skimage.morphology")
+class TestRemoveSmallObjects(unittest.TestCase):
+ @parameterized.expand(TESTS)
+ def test_remove_small_objects(self, dtype, im_type, lbl, expected, params=None):
+ params = params or {}
+ if expected is None:
+ dtype = bool if len(np.unique(lbl)) == 1 else int
+ expected = morphology.remove_small_objects(lbl.astype(dtype), **params)
+ expected = im_type(expected, dtype=dtype)
+ lbl = im_type(lbl, dtype=dtype)
+ lbl_clean = RemoveSmallObjects(**params)(lbl)
+ assert_allclose(lbl_clean, expected, device_test=True)
+ if isinstance(lbl, MetaTensor):
+ assert_allclose(lbl.affine, lbl_clean.affine)
+
+ @parameterized.expand(TESTS)
+ def test_remove_small_objects_dict(self, dtype, im_type, lbl, expected, params=None):
+ params = params or {}
+ if expected is None:
+ dtype = bool if len(np.unique(lbl)) == 1 else int
+ expected = morphology.remove_small_objects(lbl.astype(dtype), **params)
+ expected = im_type(expected, dtype=dtype)
+ lbl = im_type(lbl, dtype=dtype)
+ lbl_clean = RemoveSmallObjectsd("lbl", **params)({"lbl": lbl})["lbl"]
+ assert_allclose(lbl_clean, expected, device_test=True)
+ if isinstance(lbl, MetaTensor):
+ assert_allclose(lbl.affine, lbl_clean.affine)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_resample_backends.py b/tests/test_resample_backends.py
new file mode 100644
index 00000000000..6d231183a94
--- /dev/null
+++ b/tests/test_resample_backends.py
@@ -0,0 +1,63 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.config import USE_COMPILED
+from monai.data import MetaTensor
+from monai.transforms import Resample
+from monai.transforms.utils import create_grid
+from monai.utils import GridSampleMode, GridSamplePadMode, NdimageMode, SplineMode, convert_to_numpy
+from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, is_tf32_env
+
+_rtol = 1e-3 if is_tf32_env() else 1e-4
+
+TEST_IDENTITY = []
+for interp in GridSampleMode if not USE_COMPILED else ("nearest", "bilinear"): # type: ignore
+ for pad in GridSamplePadMode:
+ for p in (np.float32, np.float64):
+ for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
+ TEST_IDENTITY.append([dict(device=device), p, interp, pad, (1, 3, 4)])
+ if interp != "bicubic":
+ TEST_IDENTITY.append([dict(device=device), p, interp, pad, (1, 3, 5, 8)])
+for interp_s in SplineMode if not USE_COMPILED else []: # type: ignore
+ for pad_s in NdimageMode:
+ for p_s in (int, float, np.float32, np.float64):
+ for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
+ TEST_IDENTITY.append([dict(device=device), p_s, interp_s, pad_s, (1, 20, 21)])
+ TEST_IDENTITY.append([dict(device=device), p_s, interp_s, pad_s, (1, 21, 23, 24)])
+
+
+@SkipIfBeforePyTorchVersion((1, 9, 1))
+class TestResampleBackends(unittest.TestCase):
+ @parameterized.expand(TEST_IDENTITY)
+ def test_resample_identity(self, input_param, im_type, interp, padding, input_shape):
+ """test resampling of an identity grid with padding 2, im_type, interp, padding, input_shape"""
+ xform = Resample(dtype=im_type, **input_param)
+ n_elem = np.prod(input_shape)
+ img = convert_to_numpy(np.arange(n_elem).reshape(input_shape), dtype=im_type)
+ grid = create_grid(input_shape[1:], homogeneous=True, backend="numpy")
+ grid_p = np.stack([np.pad(g, 2, "constant") for g in grid]) # testing pad
+ output = xform(img=img, grid=grid_p, mode=interp, padding_mode=padding)
+ self.assertTrue(not torch.any(torch.isinf(output) | torch.isnan(output)))
+ self.assertIsInstance(output, MetaTensor)
+ slices = [slice(None)]
+ slices.extend([slice(2, -2) for _ in img.shape[1:]])
+ output_c = output[slices]
+ assert_allclose(output_c, img, rtol=_rtol, atol=1e-3, type_test="tensor")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_resample_to_match.py b/tests/test_resample_to_match.py
index f1d58e6379c..30df565a263 100644
--- a/tests/test_resample_to_match.py
+++ b/tests/test_resample_to_match.py
@@ -24,8 +24,11 @@
from monai.data.image_reader import ITKReader, NibabelReader
from monai.data.image_writer import ITKWriter
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ResampleToMatch, SaveImaged
+from monai.utils import optional_import
from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config
+_, has_itk = optional_import("itk", allow_namespace_pkg=True)
+
TEST_CASES = ["itkreader", "nibabelreader"]
@@ -36,6 +39,7 @@ def get_rand_fname(len=10, suffix=".nii.gz"):
return out
+@unittest.skipUnless(has_itk, "itk not installed")
class TestResampleToMatch(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -60,8 +64,6 @@ def test_correct(self, reader, writer):
loader = Compose([LoadImaged(("im1", "im2"), reader=reader), EnsureChannelFirstd(("im1", "im2"))])
data = loader({"im1": self.fnames[0], "im2": self.fnames[1]})
- with self.assertRaises(ValueError):
- ResampleToMatch(mode=None)(img=data["im2"], img_dst=data["im1"])
im_mod = ResampleToMatch()(data["im2"], data["im1"])
saver = SaveImaged(
"im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, writer=writer, resample=False
diff --git a/tests/test_resize.py b/tests/test_resize.py
index 8927b5dba5f..b755bb3faf7 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -74,8 +74,6 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing):
im = p(self.imt[0])
out = resize(im)
if isinstance(im, MetaTensor):
- if not out.applied_operations:
- return # skipped because good shape
im_inv = resize.inverse(out)
self.assertTrue(not im_inv.applied_operations)
assert_allclose(im_inv.shape, im.shape)
diff --git a/tests/test_resized.py b/tests/test_resized.py
index b8db6663575..a9da604b15c 100644
--- a/tests/test_resized.py
+++ b/tests/test_resized.py
@@ -13,23 +13,44 @@
import numpy as np
import skimage.transform
+import torch
from parameterized import parameterized
from monai.data import MetaTensor, set_track_meta
-from monai.transforms import Resized
+from monai.transforms import Invertd, Resize, Resized
from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion
TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)]
-TEST_CASE_1 = [{"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 10, 15)]
+TEST_CASE_1 = [
+ {"keys": "img", "spatial_size": 15, "mode": "area", "anti_aliasing": True, "anti_aliasing_sigma": None},
+ (6, 10, 15),
+]
-TEST_CASE_2 = [{"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (2, 4, 6)]
+TEST_CASE_2 = [
+ {"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True, "anti_aliasing_sigma": 2.0},
+ (2, 4, 6),
+]
TEST_CASE_3 = [
- {"keys": ["img", "label"], "spatial_size": 6, "mode": ["trilinear", "nearest"], "align_corners": [True, None]},
+ {
+ "keys": ["img", "label"],
+ "spatial_size": 6,
+ "mode": ["trilinear", "nearest"],
+ "align_corners": [True, None],
+ "anti_aliasing": [False, True],
+ "anti_aliasing_sigma": (None, 2.0),
+ },
(2, 4, 6),
]
+TEST_CORRECT_CASES = [
+ ((32, -1), "area", False),
+ ((64, 64), "area", True),
+ ((32, 32, 32), "area", True),
+ ((256, 256), "bilinear", False),
+]
+
class TestResized(NumpyImageTestCase2D):
def test_invalid_inputs(self):
@@ -41,9 +62,9 @@ def test_invalid_inputs(self):
resize = Resized(keys="img", spatial_size=(128,), mode="order")
resize({"img": self.imt[0]})
- @parameterized.expand([((32, -1), "area"), ((64, 64), "area"), ((32, 32, 32), "area"), ((256, 256), "bilinear")])
- def test_correct_results(self, spatial_size, mode):
- resize = Resized("img", spatial_size, mode=mode)
+ @parameterized.expand(TEST_CORRECT_CASES)
+ def test_correct_results(self, spatial_size, mode, anti_aliasing):
+ resize = Resized("img", spatial_size, mode=mode, anti_aliasing=anti_aliasing)
_order = 0
if mode.endswith("linear"):
_order = 1
@@ -51,7 +72,7 @@ def test_correct_results(self, spatial_size, mode):
spatial_size = (32, 64)
expected = [
skimage.transform.resize(
- channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=False
+ channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=anti_aliasing
)
for channel in self.imt[0]
]
@@ -61,7 +82,7 @@ def test_correct_results(self, spatial_size, mode):
im = p(self.imt[0])
out = resize({"img": im})
test_local_inversion(resize, out, {"img": im}, "img")
- assert_allclose(out["img"], expected, type_test=False, atol=0.9)
+ assert_allclose(out["img"], expected, type_test=False, atol=1.0)
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_longest_shape(self, input_param, expected_shape):
@@ -80,6 +101,30 @@ def test_longest_shape(self, input_param, expected_shape):
np.testing.assert_allclose(result["img"].shape[1:], expected_shape)
set_track_meta(True)
+ def test_identical_spatial(self):
+ test_input = {"X": np.ones((1, 10, 16, 17))}
+ xform = Resized("X", (-1, 16, 17))
+ out = xform(test_input)
+ out["Y"] = 2 * out["X"]
+ transform_inverse = Invertd(keys="Y", transform=xform, orig_keys="X")
+ assert_allclose(transform_inverse(out)["Y"].array, np.ones((1, 10, 16, 17)) * 2)
+
+ def test_consistent_resize(self):
+ spatial_size = (16, 16, 16)
+ rescaler_1 = Resize(spatial_size=spatial_size, anti_aliasing=True, anti_aliasing_sigma=(0.5, 1.0, 2.0))
+ rescaler_2 = Resize(spatial_size=spatial_size, anti_aliasing=True, anti_aliasing_sigma=None)
+ rescaler_dict = Resized(
+ keys=["img1", "img2"],
+ spatial_size=spatial_size,
+ anti_aliasing=(True, True),
+ anti_aliasing_sigma=[(0.5, 1.0, 2.0), None],
+ )
+ test_input_1 = torch.randn([3, 32, 32, 32])
+ test_input_2 = torch.randn([3, 32, 32, 32])
+ test_input_dict = {"img1": test_input_1, "img2": test_input_2}
+ assert_allclose(rescaler_1(test_input_1), rescaler_dict(test_input_dict)["img1"])
+ assert_allclose(rescaler_2(test_input_2), rescaler_dict(test_input_dict)["img2"])
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_resnet.py b/tests/test_resnet.py
index 88499f78d0c..ae05f362102 100644
--- a/tests/test_resnet.py
+++ b/tests/test_resnet.py
@@ -28,7 +28,6 @@
else:
torchvision, has_torchvision = optional_import("torchvision")
-
device = "cuda" if torch.cuda.is_available() else "cpu"
TEST_CASE_1 = [ # 3D, batch 3, 2 input channel
diff --git a/tests/test_retinanet.py b/tests/test_retinanet.py
index 3c136a4cf2d..f067e829629 100644
--- a/tests/test_retinanet.py
+++ b/tests/test_retinanet.py
@@ -22,7 +22,6 @@
_, has_torchvision = optional_import("torchvision")
-
device = "cuda" if torch.cuda.is_available() else "cpu"
num_anchors = 7
diff --git a/tests/test_retinanet_detector.py b/tests/test_retinanet_detector.py
index 99a70fb5fa3..243828432df 100644
--- a/tests/test_retinanet_detector.py
+++ b/tests/test_retinanet_detector.py
@@ -23,7 +23,6 @@
_, has_torchvision = optional_import("torchvision")
-
num_anchors = 7
TEST_CASE_1 = [ # 3D, batch 3, 2 input channel
diff --git a/tests/test_save_image.py b/tests/test_save_image.py
index 6591283c22d..1f4039763e2 100644
--- a/tests/test_save_image.py
+++ b/tests/test_save_image.py
@@ -18,6 +18,9 @@
from monai.data.meta_tensor import MetaTensor
from monai.transforms import SaveImage
+from monai.utils import optional_import
+
+_, has_itk = optional_import("itk", allow_namespace_pkg=True)
TEST_CASE_1 = [torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nii.gz"}, ".nii.gz", False]
@@ -33,6 +36,7 @@
]
+@unittest.skipUnless(has_itk, "itk not installed")
class TestSaveImage(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_saved_content(self, test_data, meta_data, output_ext, resample):
diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py
index 96b6fb1626c..4b079b73fd9 100644
--- a/tests/test_save_imaged.py
+++ b/tests/test_save_imaged.py
@@ -18,6 +18,9 @@
from monai.data.meta_tensor import MetaTensor
from monai.transforms import SaveImaged
+from monai.utils import optional_import
+
+_, has_itk = optional_import("itk", allow_namespace_pkg=True)
TEST_CASE_1 = [
{"img": MetaTensor(torch.randint(0, 255, (1, 2, 3, 4)), meta={"filename_or_obj": "testfile0.nii.gz"})},
@@ -44,6 +47,7 @@
]
+@unittest.skipUnless(has_itk, "itk not installed")
class TestSaveImaged(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_saved_content(self, test_data, output_ext, resample):
diff --git a/tests/test_savitzky_golay_smoothd.py b/tests/test_savitzky_golay_smoothd.py
index 6f0b33f5333..730fdeeef24 100644
--- a/tests/test_savitzky_golay_smoothd.py
+++ b/tests/test_savitzky_golay_smoothd.py
@@ -12,11 +12,10 @@
import unittest
import numpy as np
-import torch
from parameterized import parameterized
from monai.transforms import SavitzkyGolaySmoothd
-from tests.utils import TEST_NDARRAYS
+from tests.utils import TEST_NDARRAYS, assert_allclose
# Zero-padding trivial tests
@@ -65,7 +64,7 @@ class TestSavitzkyGolaySmoothd(unittest.TestCase):
def test_value(self, arguments, image, expected_data, atol):
for p in TEST_NDARRAYS:
result = SavitzkyGolaySmoothd(**arguments)({"img": p(image.astype(np.float32))})["img"]
- torch.testing.assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-4, atol=atol)
+ assert_allclose(result, p(expected_data.astype(np.float32)), rtol=1e-4, atol=atol, type_test=False)
if __name__ == "__main__":
diff --git a/tests/test_segresnet_ds.py b/tests/test_segresnet_ds.py
new file mode 100644
index 00000000000..b9a5d873dc1
--- /dev/null
+++ b/tests/test_segresnet_ds.py
@@ -0,0 +1,132 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets import SegResNetDS
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+TEST_CASE_SEGRESNET_DS = []
+for spatial_dims in range(2, 4):
+ for init_filters in [8, 16]:
+ for act in ["relu", "leakyrelu"]:
+ for norm in ["BATCH", ("instance", {"affine": True})]:
+ for upsample_mode in ["deconv", "nontrainable"]:
+ test_case = [
+ {
+ "spatial_dims": spatial_dims,
+ "init_filters": init_filters,
+ "act": act,
+ "norm": norm,
+ "upsample_mode": upsample_mode,
+ },
+ (2, 1, *([16] * spatial_dims)),
+ (2, 2, *([16] * spatial_dims)),
+ ]
+ TEST_CASE_SEGRESNET_DS.append(test_case)
+
+TEST_CASE_SEGRESNET_DS2 = []
+for spatial_dims in range(2, 4):
+ for out_channels in [1, 2]:
+ for dsdepth in [1, 2, 3]:
+ test_case = [
+ {"spatial_dims": spatial_dims, "init_filters": 8, "out_channels": out_channels, "dsdepth": dsdepth},
+ (2, 1, *([16] * spatial_dims)),
+ (2, out_channels, *([16] * spatial_dims)),
+ ]
+ TEST_CASE_SEGRESNET_DS2.append(test_case)
+
+TEST_CASE_SEGRESNET_DS3 = [
+ ({"init_filters": 8, "dsdepth": 2, "resolution": None}, (2, 1, 16, 16, 16), ((2, 2, 16, 16, 16), (2, 2, 8, 8, 8))),
+ (
+ {"init_filters": 8, "dsdepth": 3, "resolution": None},
+ (2, 1, 16, 16, 16),
+ ((2, 2, 16, 16, 16), (2, 2, 8, 8, 8), (2, 2, 4, 4, 4)),
+ ),
+ (
+ {"init_filters": 8, "dsdepth": 3, "resolution": [1, 1, 5]},
+ (2, 1, 16, 16, 16),
+ ((2, 2, 16, 16, 16), (2, 2, 8, 8, 16), (2, 2, 4, 4, 16)),
+ ),
+ (
+ {"init_filters": 8, "dsdepth": 3, "resolution": [1, 2, 5]},
+ (2, 1, 16, 16, 16),
+ ((2, 2, 16, 16, 16), (2, 2, 8, 8, 16), (2, 2, 4, 8, 16)),
+ ),
+]
+
+
+class TestResNetDS(unittest.TestCase):
+ @parameterized.expand(TEST_CASE_SEGRESNET_DS)
+ def test_shape(self, input_param, input_shape, expected_shape):
+ net = SegResNetDS(**input_param).to(device)
+ with eval_mode(net):
+ result = net(torch.randn(input_shape).to(device))
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+ @parameterized.expand(TEST_CASE_SEGRESNET_DS2)
+ def test_shape2(self, input_param, input_shape, expected_shape):
+
+ dsdepth = input_param.get("dsdepth", 1)
+ net = SegResNetDS(**input_param).to(device)
+
+ net.train()
+ result = net(torch.randn(input_shape).to(device))
+ if dsdepth > 1:
+ assert isinstance(result, list)
+ self.assertEqual(dsdepth, len(result))
+ for i in range(dsdepth):
+ self.assertEqual(
+ result[i].shape,
+ expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]),
+ msg=str(input_param),
+ )
+ else:
+ assert isinstance(result, torch.Tensor)
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+ net.eval()
+ result = net(torch.randn(input_shape).to(device))
+ assert isinstance(result, torch.Tensor)
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+ @parameterized.expand(TEST_CASE_SEGRESNET_DS3)
+ def test_shape3(self, input_param, input_shape, expected_shapes):
+
+ dsdepth = input_param.get("dsdepth", 1)
+ net = SegResNetDS(**input_param).to(device)
+
+ net.train()
+ result = net(torch.randn(input_shape).to(device))
+ assert isinstance(result, list)
+ self.assertEqual(dsdepth, len(result))
+ for i in range(dsdepth):
+ self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param))
+
+ def test_ill_arg(self):
+ with self.assertRaises(ValueError):
+ SegResNetDS(spatial_dims=4)
+
+ @SkipIfBeforePyTorchVersion((1, 10))
+ def test_script(self):
+ input_param, input_shape, _ = TEST_CASE_SEGRESNET_DS[0]
+ net = SegResNetDS(**input_param)
+ test_data = torch.randn(input_shape)
+ test_script_save(net, test_data)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_senet.py b/tests/test_senet.py
index 34f140638ea..b0d8ac0c0af 100644
--- a/tests/test_senet.py
+++ b/tests/test_senet.py
@@ -30,10 +30,8 @@
else:
pretrainedmodels, has_cadene_pretrain = optional_import("pretrainedmodels")
-
device = "cuda" if torch.cuda.is_available() else "cpu"
-
NET_ARGS = {"spatial_dims": 3, "in_channels": 2, "num_classes": 2}
TEST_CASE_1 = [SENet154, NET_ARGS]
TEST_CASE_2 = [SEResNet50, NET_ARGS]
diff --git a/tests/test_separable_filter.py b/tests/test_separable_filter.py
index e152ad2c2ba..e6838e2f9b4 100644
--- a/tests/test_separable_filter.py
+++ b/tests/test_separable_filter.py
@@ -64,7 +64,6 @@ def test_3d(self):
],
]
)
- expected = expected
# testing shapes
k = torch.tensor([1, 1, 1])
for kernel in (k, [k] * 3):
diff --git a/tests/test_shuffle_buffer.py b/tests/test_shuffle_buffer.py
index 40012fbf932..8067eee2bd9 100644
--- a/tests/test_shuffle_buffer.py
+++ b/tests/test_shuffle_buffer.py
@@ -16,18 +16,26 @@
from monai.data import DataLoader, ShuffleBuffer
from monai.utils import convert_data_type
+from tests.utils import SkipIfBeforePyTorchVersion
+@SkipIfBeforePyTorchVersion((1, 12))
class TestShuffleBuffer(unittest.TestCase):
def test_shape(self):
buffer = ShuffleBuffer([1, 2, 3, 4], seed=0)
num_workers = 2 if sys.platform == "linux" else 0
- dataloader = DataLoader(dataset=buffer, batch_size=2, num_workers=num_workers)
+ dataloader = DataLoader(
+ dataset=buffer, batch_size=2, num_workers=num_workers, persistent_workers=num_workers > 0
+ )
output = [convert_data_type(x, np.ndarray)[0] for x in dataloader]
+ buffer.seed += 1
+ output2 = [convert_data_type(x, np.ndarray)[0] for x in dataloader] # test repeating
if num_workers == 0:
np.testing.assert_allclose(output, [[2, 1], [3, 4]])
+ np.testing.assert_allclose(output2, [[3, 1], [2, 4]])
else: # multiprocess shuffle
- np.testing.assert_allclose(output, [[2, 3], [1, 4]])
+ np.testing.assert_allclose(output, [[2, 3], [1, 4]], err_msg=f"seed {buffer.seed}")
+ np.testing.assert_allclose(output2, [[1, 4], [2, 3]], err_msg=f"seed {buffer.seed}")
if __name__ == "__main__":
diff --git a/tests/test_signal_continuouswavelet.py b/tests/test_signal_continuouswavelet.py
new file mode 100644
index 00000000000..f8f028aec91
--- /dev/null
+++ b/tests/test_signal_continuouswavelet.py
@@ -0,0 +1,40 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from unittest import skipUnless
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import SignalContinuousWavelet
+from monai.utils import optional_import
+
+_, has_pywt = optional_import("pywt")
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [("mexh", 150, 500)]
+EXPECTED_RESULTS = [(6, 150, 2000)]
+
+
+@skipUnless(has_pywt, "pywt required")
+class TestSignalContinousWavelet(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, type, length, frequency):
+ self.assertIsInstance(SignalContinuousWavelet(type, length, frequency), SignalContinuousWavelet)
+ sig = np.load(TEST_SIGNAL)
+ cwt = SignalContinuousWavelet(type, length, frequency)
+ cwtsignal = cwt(sig)
+ self.assertEqual(cwtsignal.shape, EXPECTED_RESULTS[0])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_fillempty.py b/tests/test_signal_fillempty.py
new file mode 100644
index 00000000000..388426bc959
--- /dev/null
+++ b/tests/test_signal_fillempty.py
@@ -0,0 +1,48 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import numpy as np
+import torch
+
+from monai.transforms import SignalFillEmpty
+from monai.utils.type_conversion import convert_to_tensor
+from tests.utils import SkipIfBeforePyTorchVersion
+
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+
+
+@SkipIfBeforePyTorchVersion((1, 9))
+class TestSignalFillEmptyNumpy(unittest.TestCase):
+ def test_correct_parameters_multi_channels(self):
+ self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty)
+ sig = np.load(TEST_SIGNAL)
+ sig[:, 123] = np.NAN
+ fillempty = SignalFillEmpty(replacement=0.0)
+ fillemptysignal = fillempty(sig)
+ self.assertTrue(not np.isnan(fillemptysignal.any()))
+
+
+@SkipIfBeforePyTorchVersion((1, 9))
+class TestSignalFillEmptyTorch(unittest.TestCase):
+ def test_correct_parameters_multi_channels(self):
+ self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty)
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ sig[:, 123] = convert_to_tensor(np.NAN)
+ fillempty = SignalFillEmpty(replacement=0.0)
+ fillemptysignal = fillempty(sig)
+ self.assertTrue(not torch.isnan(fillemptysignal.any()))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_rand_add_gaussiannoise.py b/tests/test_signal_rand_add_gaussiannoise.py
new file mode 100644
index 00000000000..dbaf716c4b6
--- /dev/null
+++ b/tests/test_signal_rand_add_gaussiannoise.py
@@ -0,0 +1,46 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import SignalRandAddGaussianNoise
+from monai.utils.type_conversion import convert_to_tensor
+
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [([0.0, 0.02],)]
+
+
+class TestSignalRandAddGaussianNoiseNumpy(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries):
+ self.assertIsInstance(SignalRandAddGaussianNoise(boundaries), SignalRandAddGaussianNoise)
+ sig = np.load(TEST_SIGNAL)
+ gaussian = SignalRandAddGaussianNoise(boundaries)
+ gaussiansignal = gaussian(sig)
+ self.assertEqual(gaussiansignal.shape[1], sig.shape[1])
+
+
+class TestSignalRandAddGaussianNoiseTorch(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries):
+ self.assertIsInstance(SignalRandAddGaussianNoise(boundaries), SignalRandAddGaussianNoise)
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ gaussian = SignalRandAddGaussianNoise(boundaries)
+ gaussiansignal = gaussian(sig)
+ self.assertEqual(gaussiansignal.shape[1], sig.shape[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_rand_add_sine.py b/tests/test_signal_rand_add_sine.py
new file mode 100644
index 00000000000..5cb63f1496b
--- /dev/null
+++ b/tests/test_signal_rand_add_sine.py
@@ -0,0 +1,46 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import SignalRandAddSine
+from monai.utils.type_conversion import convert_to_tensor
+
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [([0.0, 1.0], [0.0, 0.5]), ([0.0, 1.0], [0.01, 0.1])]
+
+
+class TestSignalRandAddSineNumpy(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries, freqs):
+ self.assertIsInstance(SignalRandAddSine(boundaries, freqs), SignalRandAddSine)
+ sig = np.load(TEST_SIGNAL)
+ sine = SignalRandAddSine(boundaries, freqs)
+ sinesignal = sine(sig)
+ self.assertEqual(sinesignal.shape[1], sig.shape[1])
+
+
+class TestSignalRandAddSineTorch(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries, freqs):
+ self.assertIsInstance(SignalRandAddSine(boundaries, freqs), SignalRandAddSine)
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ sine = SignalRandAddSine(boundaries, freqs)
+ sinesignal = sine(sig)
+ self.assertEqual(sinesignal.shape[1], sig.shape[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_rand_add_sine_partial.py b/tests/test_signal_rand_add_sine_partial.py
new file mode 100644
index 00000000000..c04e6b138c7
--- /dev/null
+++ b/tests/test_signal_rand_add_sine_partial.py
@@ -0,0 +1,46 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import SignalRandAddSinePartial
+from monai.utils.type_conversion import convert_to_tensor
+
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [([0.0, 1.0], [0.1, 0.6], [0.0, 0.4])]
+
+
+class TestSignalRandAddSinePartialNumpy(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):
+ self.assertIsInstance(SignalRandAddSinePartial(boundaries, frequencies, fraction), SignalRandAddSinePartial)
+ sig = np.load(TEST_SIGNAL)
+ partialsine = SignalRandAddSinePartial(boundaries, frequencies, fraction)
+ partialsinesignal = partialsine(sig)
+ self.assertEqual(partialsinesignal.shape[1], sig.shape[1])
+
+
+class TestSignalRandAddSinePartialTorch(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):
+ self.assertIsInstance(SignalRandAddSinePartial(boundaries, frequencies, fraction), SignalRandAddSinePartial)
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ partialsine = SignalRandAddSinePartial(boundaries, frequencies, fraction)
+ partialsinesignal = partialsine(sig)
+ self.assertEqual(partialsinesignal.shape[1], sig.shape[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_rand_add_squarepulse.py b/tests/test_signal_rand_add_squarepulse.py
new file mode 100644
index 00000000000..6c96f695778
--- /dev/null
+++ b/tests/test_signal_rand_add_squarepulse.py
@@ -0,0 +1,54 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from unittest import skipUnless
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import SignalRandAddSquarePulse
+from monai.utils import optional_import
+from monai.utils.type_conversion import convert_to_tensor
+from tests.utils import SkipIfBeforePyTorchVersion
+
+_, has_scipy = optional_import("scipy")
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [([0.0, 1.0], [0.001, 0.2])]
+
+
+@skipUnless(has_scipy, "scipy required")
+@SkipIfBeforePyTorchVersion((1, 10, 1))
+class TestSignalRandAddSquarePulseNumpy(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries, frequencies):
+ self.assertIsInstance(SignalRandAddSquarePulse(boundaries, frequencies), SignalRandAddSquarePulse)
+ sig = np.load(TEST_SIGNAL)
+ squared = SignalRandAddSquarePulse(boundaries, frequencies)
+ squaredsignal = squared(sig)
+ self.assertEqual(squaredsignal.shape[1], sig.shape[1])
+
+
+@skipUnless(has_scipy, "scipy required")
+@SkipIfBeforePyTorchVersion((1, 10, 1))
+class TestSignalRandAddSquarePulseTorch(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries, frequencies):
+ self.assertIsInstance(SignalRandAddSquarePulse(boundaries, frequencies), SignalRandAddSquarePulse)
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ squared = SignalRandAddSquarePulse(boundaries, frequencies)
+ squaredsignal = squared(sig)
+ self.assertEqual(squaredsignal.shape[1], sig.shape[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_rand_add_squarepulse_partial.py b/tests/test_signal_rand_add_squarepulse_partial.py
new file mode 100644
index 00000000000..dd7aeae7935
--- /dev/null
+++ b/tests/test_signal_rand_add_squarepulse_partial.py
@@ -0,0 +1,58 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from unittest import skipUnless
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import SignalRandAddSquarePulsePartial
+from monai.utils import optional_import
+from monai.utils.type_conversion import convert_to_tensor
+from tests.utils import SkipIfBeforePyTorchVersion
+
+_, has_scipy = optional_import("scipy")
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [([0.0, 1.0], [0.001, 0.2], [0.0, 0.4])]
+
+
+@skipUnless(has_scipy, "scipy required")
+@SkipIfBeforePyTorchVersion((1, 10, 1))
+class TestSignalRandAddSquarePulsePartialNumpy(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):
+ self.assertIsInstance(
+ SignalRandAddSquarePulsePartial(boundaries, frequencies, fraction), SignalRandAddSquarePulsePartial
+ )
+ sig = np.load(TEST_SIGNAL)
+ partialsquare = SignalRandAddSquarePulsePartial(boundaries, frequencies, fraction)
+ partialsquaresignal = partialsquare(sig)
+ self.assertEqual(partialsquaresignal.shape[1], sig.shape[1])
+
+
+@skipUnless(has_scipy, "scipy required")
+@SkipIfBeforePyTorchVersion((1, 10, 1))
+class TestSignalRandAddSquarePulsePartialTorch(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries, frequencies, fraction):
+ self.assertIsInstance(
+ SignalRandAddSquarePulsePartial(boundaries, frequencies, fraction), SignalRandAddSquarePulsePartial
+ )
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ partialsquare = SignalRandAddSquarePulsePartial(boundaries, frequencies, fraction)
+ partialsquaresignal = partialsquare(sig)
+ self.assertEqual(partialsquaresignal.shape[1], sig.shape[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_rand_drop.py b/tests/test_signal_rand_drop.py
new file mode 100644
index 00000000000..4235ae6d877
--- /dev/null
+++ b/tests/test_signal_rand_drop.py
@@ -0,0 +1,46 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import SignalRandDrop
+from monai.utils.type_conversion import convert_to_tensor
+
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [([0.0, 1.0],), ([0.01, 0.1],)]
+
+
+class TestSignalRandDropNumpy(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries):
+ self.assertIsInstance(SignalRandDrop(boundaries), SignalRandDrop)
+ sig = np.load(TEST_SIGNAL)
+ droped = SignalRandDrop(boundaries)
+ dropedsignal = droped(sig)
+ self.assertEqual(dropedsignal.shape[1], sig.shape[1])
+
+
+class TestSignalRandDropTorch(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries):
+ self.assertIsInstance(SignalRandDrop(boundaries), SignalRandDrop)
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ droped = SignalRandDrop(boundaries)
+ dropedsignal = droped(sig)
+ self.assertEqual(dropedsignal.shape[1], sig.shape[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_rand_scale.py b/tests/test_signal_rand_scale.py
new file mode 100644
index 00000000000..2ac708ef190
--- /dev/null
+++ b/tests/test_signal_rand_scale.py
@@ -0,0 +1,46 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import SignalRandScale
+from monai.utils.type_conversion import convert_to_tensor
+
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [([-1.0, 1.0],), ([0.01, 0.1],)]
+
+
+class TestSignalRandScaleNumpy(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries):
+ self.assertIsInstance(SignalRandScale(boundaries), SignalRandScale)
+ sig = np.load(TEST_SIGNAL)
+ scaled = SignalRandScale(boundaries)
+ scaledsignal = scaled(sig)
+ self.assertEqual(scaledsignal.shape[1], sig.shape[1])
+
+
+class TestSignalRandScaleTorch(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, boundaries):
+ self.assertIsInstance(SignalRandScale(boundaries), SignalRandScale)
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ scaled = SignalRandScale(boundaries)
+ scaledsignal = scaled(sig)
+ self.assertEqual(scaledsignal.shape[1], sig.shape[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_rand_shift.py b/tests/test_signal_rand_shift.py
new file mode 100644
index 00000000000..402cd433f8b
--- /dev/null
+++ b/tests/test_signal_rand_shift.py
@@ -0,0 +1,51 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from unittest import skipUnless
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms.signal.array import SignalRandShift
+from monai.utils import optional_import
+from monai.utils.type_conversion import convert_to_tensor
+
+_, has_scipy = optional_import("scipy")
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [("wrap", 0.0, [-1.0, 1.0])]
+
+
+@skipUnless(has_scipy, "scipy required")
+class TestSignalRandShiftNumpy(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, mode, filling, boundaries):
+ self.assertIsInstance(SignalRandShift(mode, filling, boundaries), SignalRandShift)
+ sig = np.load(TEST_SIGNAL)
+ shifted = SignalRandShift(mode, filling, boundaries)
+ shiftedsignal = shifted(sig)
+ self.assertEqual(shiftedsignal.shape[1], sig.shape[1])
+
+
+@skipUnless(has_scipy, "scipy required")
+class TestSignalRandShiftTorch(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, mode, filling, boundaries):
+ self.assertIsInstance(SignalRandShift(mode, filling, boundaries), SignalRandShift)
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ shifted = SignalRandShift(mode, filling, boundaries)
+ shiftedsignal = shifted(sig)
+ self.assertEqual(shiftedsignal.shape[1], sig.shape[1])
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_signal_remove_frequency.py b/tests/test_signal_remove_frequency.py
new file mode 100644
index 00000000000..fa70c4f795e
--- /dev/null
+++ b/tests/test_signal_remove_frequency.py
@@ -0,0 +1,67 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# you may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANy KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from unittest import skipUnless
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.transforms import SignalRemoveFrequency
+from monai.utils import optional_import
+from monai.utils.type_conversion import convert_to_tensor
+
+_, has_scipy = optional_import("scipy")
+_, has_torchaudio = optional_import("torchaudio")
+TEST_SIGNAL = os.path.join(os.path.dirname(__file__), "testing_data", "signal.npy")
+VALID_CASES = [(60, 1, 500)]
+
+
+@skipUnless(has_scipy and has_torchaudio, "scipy and torchaudio are required")
+class TestSignalRemoveFrequencyNumpy(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, frequency, quality_factor, sampling_freq):
+ self.assertIsInstance(SignalRemoveFrequency(frequency, quality_factor, sampling_freq), SignalRemoveFrequency)
+ sig = np.load(TEST_SIGNAL)
+ t = sig.shape[1] / sampling_freq
+ composite_sig = sig + np.sin(2 * np.pi * frequency * t)
+ freqremove = SignalRemoveFrequency(frequency, quality_factor, sampling_freq)
+ freqremovesignal = freqremove(composite_sig)
+ y = np.fft.fft(composite_sig) / composite_sig.shape[1]
+ y = y[: composite_sig.shape[1] // 2]
+ y2 = np.fft.fft(freqremovesignal) / freqremovesignal.shape[1]
+ y2 = y2[: freqremovesignal.shape[1] // 2]
+ self.assertEqual(composite_sig.shape[1], sig.shape[1])
+ self.assertAlmostEqual(y.all(), y2.all())
+
+
+@skipUnless(has_scipy and has_torchaudio, "scipy and torchaudio are required")
+class TestSignalRemoveFrequencyTorch(unittest.TestCase):
+ @parameterized.expand(VALID_CASES)
+ def test_correct_parameters_multi_channels(self, frequency, quality_factor, sampling_freq):
+ self.assertIsInstance(SignalRemoveFrequency(frequency, quality_factor, sampling_freq), SignalRemoveFrequency)
+ sig = convert_to_tensor(np.load(TEST_SIGNAL))
+ t = sig.shape[1] / sampling_freq
+ composite_sig = convert_to_tensor(sig + np.sin(2 * np.pi * frequency * t))
+ freqremove = SignalRemoveFrequency(frequency, quality_factor, sampling_freq)
+ freqremovesignal = freqremove(composite_sig)
+ y = torch.fft.fft(composite_sig) / composite_sig.shape[1]
+ y = y[: composite_sig.shape[1] // 2]
+ y2 = torch.fft.fft(freqremovesignal) / freqremovesignal.shape[1]
+ y2 = y2[: freqremovesignal.shape[1] // 2]
+ self.assertEqual(composite_sig.shape[1], sig.shape[1])
+ self.assertAlmostEqual(y.all(), y2.all())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_sliding_patch_wsi_dataset.py b/tests/test_sliding_patch_wsi_dataset.py
index d639d000c5f..06395cf26cf 100644
--- a/tests/test_sliding_patch_wsi_dataset.py
+++ b/tests/test_sliding_patch_wsi_dataset.py
@@ -30,7 +30,6 @@
_, has_codec = optional_import("imagecodecs")
has_tiff = has_tiff and has_codec
-
FILE_KEY = "wsi_img"
FILE_URL = testing_data_config("images", FILE_KEY, "url")
base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff"
@@ -143,7 +142,6 @@
],
]
-
TEST_CASE_SMALL_7 = [
{"data": [{"image": FILE_PATH_SMALL_0, WSIPatchKeys.LEVEL: 0, WSIPatchKeys.SIZE: (2, 2)}], "offset": (1, 0)},
[{"image": ARRAY_SMALL_0[:, 1:3, :2]}, {"image": ARRAY_SMALL_0[:, 1:3, 2:]}],
@@ -244,11 +242,11 @@ def test_read_patches_large(self, input_parameters, expected):
dataset = SlidingPatchWSIDataset(reader=self.backend, **input_parameters)
self.assertEqual(len(dataset), len(expected))
for i, sample in enumerate(dataset):
- self.assertEqual(sample["metadata"][WSIPatchKeys.LEVEL], expected[i]["patch_level"])
- assert_array_equal(sample["metadata"][WSIPatchKeys.SIZE], expected[i]["patch_size"])
+ self.assertEqual(sample["image"].meta[WSIPatchKeys.LEVEL], expected[i]["patch_level"])
+ assert_array_equal(sample["image"].meta[WSIPatchKeys.SIZE], expected[i]["patch_size"])
steps = [round(expected[i]["ratio"] * s) for s in expected[i]["patch_size"]]
expected_location = tuple(expected[i]["step_loc"][j] * steps[j] for j in range(len(steps)))
- assert_array_equal(sample["metadata"][WSIPatchKeys.LOCATION], expected_location)
+ assert_array_equal(sample["image"].meta[WSIPatchKeys.LOCATION], expected_location)
@skipUnless(has_cucim, "Requires cucim")
diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
index 8b8ec47d32c..b10c1c659e7 100644
--- a/tests/test_sliding_window_inference.py
+++ b/tests/test_sliding_window_inference.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import unittest
import numpy as np
@@ -85,20 +86,20 @@ def compute(data):
expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1
np.testing.assert_allclose(result.cpu().numpy(), expected_val)
- @parameterized.expand([[x] for x in TEST_TORCH_AND_META_TENSORS])
+ @parameterized.expand(list(itertools.product(TEST_TORCH_AND_META_TENSORS, ("cpu", "cuda"), ("cpu", "cuda", None))))
@skip_if_no_cuda
- def test_sw_device(self, data_type):
- inputs = data_type(torch.ones((3, 16, 15, 7))).to(device="cpu")
+ def test_sw_device(self, data_type, device, sw_device):
+ inputs = data_type(torch.ones((3, 16, 15, 7))).to(device=device)
inputs = list_data_collate([inputs]) # make a proper batch
roi_shape = (4, 10, 7)
sw_batch_size = 10
def compute(data):
- self.assertEqual(data.device.type, "cuda")
- return data + torch.tensor(1, device="cuda")
+ self.assertEqual(data.device.type, sw_device or device)
+ return data + torch.tensor(1, device=sw_device or device)
- result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, sw_device="cuda")
- np.testing.assert_string_equal(inputs.device.type, result.device.type)
+ result = sliding_window_inference(inputs, roi_shape, sw_batch_size, compute, sw_device=sw_device, device="cpu")
+ np.testing.assert_string_equal("cpu", result.device.type)
expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1
np.testing.assert_allclose(result.cpu().numpy(), expected_val)
diff --git a/tests/test_smartcache_patch_wsi_dataset.py b/tests/test_smartcache_patch_wsi_dataset.py
index e2150edce58..5760264a7bb 100644
--- a/tests/test_smartcache_patch_wsi_dataset.py
+++ b/tests/test_smartcache_patch_wsi_dataset.py
@@ -152,8 +152,7 @@ def test_read_patches(self, input_parameters, expected):
dataset.start()
i = 0
for _ in range(num_epochs):
- for j in range(len(dataset)):
- samples = dataset[j]
+ for samples in dataset:
n_patches = len(samples)
self.assert_samples_expected(samples, expected[i : i + n_patches])
i += n_patches
@@ -161,11 +160,11 @@ def test_read_patches(self, input_parameters, expected):
dataset.shutdown()
def assert_samples_expected(self, samples, expected):
- for i in range(len(samples)):
- self.assertTupleEqual(samples[i]["label"].shape, expected[i]["label"].shape)
- self.assertTupleEqual(samples[i]["image"].shape, expected[i]["image"].shape)
- self.assertIsNone(assert_array_equal(samples[i]["label"], expected[i]["label"]))
- self.assertIsNone(assert_array_equal(samples[i]["image"], expected[i]["image"]))
+ for i, item in enumerate(samples):
+ self.assertTupleEqual(item["label"].shape, expected[i]["label"].shape)
+ self.assertTupleEqual(item["image"].shape, expected[i]["image"].shape)
+ self.assertIsNone(assert_array_equal(item["label"], expected[i]["label"]))
+ self.assertIsNone(assert_array_equal(item["image"], expected[i]["image"]))
if __name__ == "__main__":
diff --git a/tests/test_smooth_field.py b/tests/test_smooth_field.py
index c67865ba393..b731af36f41 100644
--- a/tests/test_smooth_field.py
+++ b/tests/test_smooth_field.py
@@ -16,12 +16,19 @@
import torch
from parameterized import parameterized
+from monai.networks.utils import meshgrid_xy
from monai.transforms import RandSmoothDeformd, RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd
from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env
_rtol = 5e-3 if is_tf32_env() else 1e-4
-INPUT_SHAPES = ((1, 8, 8), (2, 8, 8), (1, 8, 8, 8))
+x, y = meshgrid_xy(torch.linspace(-1, 2, 11), torch.linspace(-2.1, 1.2, 8))
+pattern2d = x.pow(2).add_(y.pow(2)).sqrt_()
+
+x, y, z = meshgrid_xy(torch.linspace(-1, 2, 11), torch.linspace(-2.1, 1.2, 8), torch.linspace(-0.1, 10.2, 6))
+pattern3d = x.pow(2).add_(y.pow(2)).add_(z.pow(2)).sqrt_()
+
+INPUT_SHAPES = ((1, 8, 8), (1, 12, 7), (2, 8, 8), (2, 13, 8), (1, 8, 8, 8), (3, 7, 4, 5))
TESTS_CONTRAST = []
TESTS_INTENSITY = []
@@ -131,6 +138,27 @@ def test_rand_smooth_deformd(self, input_param, input_data, expected_val):
expected = expected_val[key]
assert_allclose(result, expected, rtol=_rtol, atol=1e-1, type_test="tensor")
+ def test_rand_smooth_nodeformd(self):
+ """Test input is very close to output when deformation is very low, verifies there's no transposition."""
+
+ for label, im in zip(("2D", "3D"), (pattern2d, pattern3d)):
+ with self.subTest(f"Testing {label} case with shape {im.shape}"):
+ rsize = (3,) * len(im.shape)
+ g = RandSmoothDeformd(
+ keys=(KEY,), spatial_size=im.shape, rand_size=rsize, prob=1.0, device=device, def_range=1e-20
+ )
+ g.set_random_state(123)
+
+ expected_val = {KEY: im[None]}
+
+ res = g(expected_val)
+ for key, result in res.items():
+ expected = expected_val[key]
+
+ self.assertSequenceEqual(tuple(result.shape), tuple(expected.shape))
+
+ assert_allclose(result, expected, rtol=_rtol, atol=1e-1, type_test="tensor")
+
def test_rand_smooth_deformd_pad(self):
input_param, input_data, expected_val = TESTS_DEFORM[0]
diff --git a/tests/test_sobel_gradient.py b/tests/test_sobel_gradient.py
new file mode 100644
index 00000000000..ba092516b7c
--- /dev/null
+++ b/tests/test_sobel_gradient.py
@@ -0,0 +1,187 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import SobelGradients
+from tests.utils import assert_allclose
+
+IMAGE = torch.zeros(1, 16, 16, dtype=torch.float32)
+IMAGE[0, 8, :] = 1
+
+# Output with reflect padding
+OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)
+OUTPUT_3x3[1, 7, :] = 0.5
+OUTPUT_3x3[1, 9, :] = -0.5
+
+# Output with zero padding
+OUTPUT_3x3_ZERO_PAD = OUTPUT_3x3.clone()
+OUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, 0] = 0.125
+OUTPUT_3x3_ZERO_PAD[0, 8, 0] = 0.25
+OUTPUT_3x3_ZERO_PAD[0, 7, -1] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -0.125
+OUTPUT_3x3_ZERO_PAD[0, 8, -1] = -0.25
+OUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 7, -1] = 3.0 / 8.0
+OUTPUT_3x3_ZERO_PAD[1, 9, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -3.0 / 8.0
+
+TEST_CASE_0 = [IMAGE, {"kernel_size": 3, "dtype": torch.float32}, OUTPUT_3x3]
+TEST_CASE_1 = [IMAGE, {"kernel_size": 3, "dtype": torch.float64}, OUTPUT_3x3]
+TEST_CASE_2 = [IMAGE, {"kernel_size": 3, "spatial_axes": 0, "dtype": torch.float64}, OUTPUT_3x3[0:1]]
+TEST_CASE_3 = [IMAGE, {"kernel_size": 3, "spatial_axes": 1, "dtype": torch.float64}, OUTPUT_3x3[1:2]]
+TEST_CASE_4 = [IMAGE, {"kernel_size": 3, "spatial_axes": [1], "dtype": torch.float64}, OUTPUT_3x3[1:2]]
+TEST_CASE_5 = [
+ IMAGE,
+ {"kernel_size": 3, "spatial_axes": [0, 1], "normalize_kernels": True, "dtype": torch.float64},
+ OUTPUT_3x3,
+]
+TEST_CASE_6 = [
+ IMAGE,
+ {"kernel_size": 3, "spatial_axes": (0, 1), "padding_mode": "reflect", "dtype": torch.float64},
+ OUTPUT_3x3,
+]
+TEST_CASE_7 = [
+ IMAGE,
+ {"kernel_size": 3, "spatial_axes": (0, 1), "padding_mode": "zeros", "dtype": torch.float64},
+ OUTPUT_3x3_ZERO_PAD,
+]
+TEST_CASE_8 = [ # Non-normalized kernels
+ IMAGE,
+ {"kernel_size": 3, "normalize_kernels": False, "dtype": torch.float32},
+ OUTPUT_3x3 * 8.0,
+]
+TEST_CASE_9 = [ # Normalized gradients and normalized kernels
+ IMAGE,
+ {
+ "kernel_size": 3,
+ "normalize_kernels": True,
+ "normalize_gradients": True,
+ "spatial_axes": (0, 1),
+ "dtype": torch.float64,
+ },
+ torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5]),
+]
+TEST_CASE_10 = [ # Normalized gradients but non-normalized kernels
+ IMAGE,
+ {
+ "kernel_size": 3,
+ "normalize_kernels": False,
+ "normalize_gradients": True,
+ "spatial_axes": (0, 1),
+ "dtype": torch.float64,
+ },
+ torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5]),
+]
+
+TEST_CASE_KERNEL_0 = [
+ {"kernel_size": 3, "dtype": torch.float64},
+ (torch.tensor([-0.5, 0.0, 0.5], dtype=torch.float64), torch.tensor([0.25, 0.5, 0.25], dtype=torch.float64)),
+]
+TEST_CASE_KERNEL_1 = [
+ {"kernel_size": 5, "dtype": torch.float64},
+ (
+ torch.tensor([-0.1250, -0.2500, 0.0000, 0.2500, 0.1250], dtype=torch.float64),
+ torch.tensor([0.0625, 0.2500, 0.3750, 0.2500, 0.0625], dtype=torch.float64),
+ ),
+]
+TEST_CASE_KERNEL_2 = [
+ {"kernel_size": 7, "dtype": torch.float64},
+ (
+ torch.tensor([-0.03125, -0.125, -0.15625, 0.0, 0.15625, 0.125, 0.03125], dtype=torch.float64),
+ torch.tensor([0.015625, 0.09375, 0.234375, 0.3125, 0.234375, 0.09375, 0.015625], dtype=torch.float64),
+ ),
+]
+TEST_CASE_KERNEL_NON_NORMALIZED_0 = [
+ {"kernel_size": 3, "normalize_kernels": False, "dtype": torch.float64},
+ (torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float64), torch.tensor([1.0, 2.0, 1.0], dtype=torch.float64)),
+]
+TEST_CASE_KERNEL_NON_NORMALIZED_1 = [
+ {"kernel_size": 5, "normalize_kernels": False, "dtype": torch.float64},
+ (
+ torch.tensor([-1.0, -2.0, 0.0, 2.0, 1.0], dtype=torch.float64),
+ torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0], dtype=torch.float64),
+ ),
+]
+TEST_CASE_KERNEL_NON_NORMALIZED_2 = [
+ {"kernel_size": 7, "normalize_kernels": False, "dtype": torch.float64},
+ (
+ torch.tensor([-1.0, -4.0, -5.0, 0.0, 5.0, 4.0, 1.0], dtype=torch.float64),
+ torch.tensor([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0], dtype=torch.float64),
+ ),
+]
+
+TEST_CASE_ERROR_0 = [IMAGE, {"kernel_size": 1}] # kernel size less than 3
+TEST_CASE_ERROR_1 = [IMAGE, {"kernel_size": 4}] # even kernel size
+TEST_CASE_ERROR_2 = [IMAGE, {"spatial_axes": "horizontal"}] # wrong type direction
+TEST_CASE_ERROR_3 = [IMAGE, {"spatial_axes": 3}] # wrong direction
+TEST_CASE_ERROR_4 = [IMAGE, {"spatial_axes": [3]}] # wrong direction in a list
+TEST_CASE_ERROR_5 = [IMAGE, {"spatial_axes": [0, 4]}] # correct and wrong direction in a list
+
+
+class SobelGradientTests(unittest.TestCase):
+ backend = None
+
+ @parameterized.expand(
+ [
+ TEST_CASE_0,
+ TEST_CASE_1,
+ TEST_CASE_2,
+ TEST_CASE_3,
+ TEST_CASE_4,
+ TEST_CASE_5,
+ TEST_CASE_6,
+ TEST_CASE_7,
+ TEST_CASE_8,
+ TEST_CASE_9,
+ TEST_CASE_10,
+ ]
+ )
+ def test_sobel_gradients(self, image, arguments, expected_grad):
+ sobel = SobelGradients(**arguments)
+ grad = sobel(image)
+ assert_allclose(grad, expected_grad)
+
+ @parameterized.expand(
+ [
+ TEST_CASE_KERNEL_0,
+ TEST_CASE_KERNEL_1,
+ TEST_CASE_KERNEL_2,
+ TEST_CASE_KERNEL_NON_NORMALIZED_0,
+ TEST_CASE_KERNEL_NON_NORMALIZED_1,
+ TEST_CASE_KERNEL_NON_NORMALIZED_2,
+ ]
+ )
+ def test_sobel_kernels(self, arguments, expected_kernels):
+ sobel = SobelGradients(**arguments)
+ self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype)
+ self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype)
+ assert_allclose(sobel.kernel_diff, expected_kernels[0])
+ assert_allclose(sobel.kernel_smooth, expected_kernels[1])
+
+ @parameterized.expand(
+ [
+ TEST_CASE_ERROR_0,
+ TEST_CASE_ERROR_1,
+ TEST_CASE_ERROR_2,
+ TEST_CASE_ERROR_3,
+ TEST_CASE_ERROR_4,
+ TEST_CASE_ERROR_5,
+ ]
+ )
+ def test_sobel_gradients_error(self, image, arguments):
+ with self.assertRaises(ValueError):
+ sobel = SobelGradients(**arguments)
+ sobel(image)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_sobel_gradientd.py b/tests/test_sobel_gradientd.py
new file mode 100644
index 00000000000..b53812aeb95
--- /dev/null
+++ b/tests/test_sobel_gradientd.py
@@ -0,0 +1,210 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import SobelGradientsd
+from tests.utils import assert_allclose
+
+IMAGE = torch.zeros(1, 16, 16, dtype=torch.float32)
+IMAGE[0, 8, :] = 1
+
+# Output with reflect padding
+OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)
+OUTPUT_3x3[1, 7, :] = 0.5
+OUTPUT_3x3[1, 9, :] = -0.5
+
+# Output with zero padding
+OUTPUT_3x3_ZERO_PAD = OUTPUT_3x3.clone()
+OUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, 0] = 0.125
+OUTPUT_3x3_ZERO_PAD[0, 8, 0] = 0.25
+OUTPUT_3x3_ZERO_PAD[0, 7, -1] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -0.125
+OUTPUT_3x3_ZERO_PAD[0, 8, -1] = -0.25
+OUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 7, -1] = 3.0 / 8.0
+OUTPUT_3x3_ZERO_PAD[1, 9, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -3.0 / 8.0
+
+TEST_CASE_0 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 3, "dtype": torch.float32}, {"image": OUTPUT_3x3}]
+TEST_CASE_1 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 3, "dtype": torch.float64}, {"image": OUTPUT_3x3}]
+TEST_CASE_2 = [
+ {"image": IMAGE},
+ {"keys": "image", "kernel_size": 3, "dtype": torch.float32, "new_key_prefix": "sobel_"},
+ {"sobel_image": OUTPUT_3x3},
+]
+TEST_CASE_3 = [
+ {"image": IMAGE},
+ {"keys": "image", "kernel_size": 3, "spatial_axes": 0, "dtype": torch.float32},
+ {"image": OUTPUT_3x3[0][None, ...]},
+]
+TEST_CASE_4 = [
+ {"image": IMAGE},
+ {"keys": "image", "kernel_size": 3, "spatial_axes": 1, "dtype": torch.float32},
+ {"image": OUTPUT_3x3[1][None, ...]},
+]
+TEST_CASE_5 = [
+ {"image": IMAGE},
+ {"keys": "image", "kernel_size": 3, "spatial_axes": [1], "dtype": torch.float32},
+ {"image": OUTPUT_3x3[1][None, ...]},
+]
+TEST_CASE_6 = [
+ {"image": IMAGE},
+ {"keys": "image", "kernel_size": 3, "spatial_axes": [0, 1], "normalize_kernels": True, "dtype": torch.float32},
+ {"image": OUTPUT_3x3},
+]
+TEST_CASE_7 = [
+ {"image": IMAGE},
+ {"keys": "image", "kernel_size": 3, "spatial_axes": (0, 1), "padding_mode": "reflect", "dtype": torch.float32},
+ {"image": OUTPUT_3x3},
+]
+TEST_CASE_8 = [
+ {"image": IMAGE},
+ {"keys": "image", "kernel_size": 3, "spatial_axes": (0, 1), "padding_mode": "zeros", "dtype": torch.float32},
+ {"image": OUTPUT_3x3_ZERO_PAD},
+]
+TEST_CASE_9 = [ # Non-normalized kernels
+ {"image": IMAGE},
+ {"keys": "image", "kernel_size": 3, "spatial_axes": (0, 1), "normalize_kernels": False, "dtype": torch.float32},
+ {"image": OUTPUT_3x3 * 8.0},
+]
+TEST_CASE_10 = [ # Normalized gradients and normalized kernels
+ {"image": IMAGE},
+ {
+ "keys": "image",
+ "kernel_size": 3,
+ "spatial_axes": (0, 1),
+ "normalize_kernels": True,
+ "normalize_gradients": True,
+ "dtype": torch.float32,
+ },
+ {"image": torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5])},
+]
+TEST_CASE_11 = [ # Normalized gradients but non-normalized kernels
+ {"image": IMAGE},
+ {
+ "keys": "image",
+ "kernel_size": 3,
+ "spatial_axes": (0, 1),
+ "normalize_kernels": False,
+ "normalize_gradients": True,
+ "dtype": torch.float32,
+ },
+ {"image": torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5])},
+]
+
+TEST_CASE_KERNEL_0 = [
+ {"keys": "image", "kernel_size": 3, "dtype": torch.float64},
+ (torch.tensor([-0.5, 0.0, 0.5], dtype=torch.float64), torch.tensor([0.25, 0.5, 0.25], dtype=torch.float64)),
+]
+TEST_CASE_KERNEL_1 = [
+ {"keys": "image", "kernel_size": 5, "dtype": torch.float64},
+ (
+ torch.tensor([-0.1250, -0.2500, 0.0000, 0.2500, 0.1250], dtype=torch.float64),
+ torch.tensor([0.0625, 0.2500, 0.3750, 0.2500, 0.0625], dtype=torch.float64),
+ ),
+]
+TEST_CASE_KERNEL_2 = [
+ {"keys": "image", "kernel_size": 7, "dtype": torch.float64},
+ (
+ torch.tensor([-0.03125, -0.125, -0.15625, 0.0, 0.15625, 0.125, 0.03125], dtype=torch.float64),
+ torch.tensor([0.015625, 0.09375, 0.234375, 0.3125, 0.234375, 0.09375, 0.015625], dtype=torch.float64),
+ ),
+]
+TEST_CASE_KERNEL_NON_NORMALIZED_0 = [
+ {"keys": "image", "kernel_size": 3, "normalize_kernels": False, "dtype": torch.float64},
+ (torch.tensor([-1.0, 0.0, 1.0], dtype=torch.float64), torch.tensor([1.0, 2.0, 1.0], dtype=torch.float64)),
+]
+TEST_CASE_KERNEL_NON_NORMALIZED_1 = [
+ {"keys": "image", "kernel_size": 5, "normalize_kernels": False, "dtype": torch.float64},
+ (
+ torch.tensor([-1.0, -2.0, 0.0, 2.0, 1.0], dtype=torch.float64),
+ torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0], dtype=torch.float64),
+ ),
+]
+TEST_CASE_KERNEL_NON_NORMALIZED_2 = [
+ {"keys": "image", "kernel_size": 7, "normalize_kernels": False, "dtype": torch.float64},
+ (
+ torch.tensor([-1.0, -4.0, -5.0, 0.0, 5.0, 4.0, 1.0], dtype=torch.float64),
+ torch.tensor([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0], dtype=torch.float64),
+ ),
+]
+TEST_CASE_ERROR_0 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 1}] # kernel size less than 3
+TEST_CASE_ERROR_1 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 4}] # even kernel size
+TEST_CASE_ERROR_2 = [{"image": IMAGE}, {"keys": "image", "spatial_axes": "horizontal"}] # wrong type direction
+TEST_CASE_ERROR_3 = [{"image": IMAGE}, {"keys": "image", "spatial_axes": 3}] # wrong direction
+TEST_CASE_ERROR_4 = [{"image": IMAGE}, {"keys": "image", "spatial_axes": [3]}] # wrong direction in a list
+TEST_CASE_ERROR_5 = [
+ {"image": IMAGE},
+ {"keys": "image", "spatial_axes": [0, 4]},
+] # correct and wrong direction in a list
+
+
+class SobelGradientTests(unittest.TestCase):
+ backend = None
+
+ @parameterized.expand(
+ [
+ TEST_CASE_0,
+ TEST_CASE_1,
+ TEST_CASE_2,
+ TEST_CASE_3,
+ TEST_CASE_4,
+ TEST_CASE_5,
+ TEST_CASE_6,
+ TEST_CASE_7,
+ TEST_CASE_8,
+ TEST_CASE_9,
+ TEST_CASE_10,
+ TEST_CASE_11,
+ ]
+ )
+ def test_sobel_gradients(self, image_dict, arguments, expected_grad):
+ sobel = SobelGradientsd(**arguments)
+ grad = sobel(image_dict)
+ key = "image" if "new_key_prefix" not in arguments else arguments["new_key_prefix"] + arguments["keys"]
+ assert_allclose(grad[key], expected_grad[key])
+
+ @parameterized.expand(
+ [
+ TEST_CASE_KERNEL_0,
+ TEST_CASE_KERNEL_1,
+ TEST_CASE_KERNEL_2,
+ TEST_CASE_KERNEL_NON_NORMALIZED_0,
+ TEST_CASE_KERNEL_NON_NORMALIZED_1,
+ TEST_CASE_KERNEL_NON_NORMALIZED_2,
+ ]
+ )
+ def test_sobel_kernels(self, arguments, expected_kernels):
+ sobel = SobelGradientsd(**arguments)
+ self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype)
+ self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype)
+ assert_allclose(sobel.kernel_diff, expected_kernels[0])
+ assert_allclose(sobel.kernel_smooth, expected_kernels[1])
+
+ @parameterized.expand(
+ [
+ TEST_CASE_ERROR_0,
+ TEST_CASE_ERROR_1,
+ TEST_CASE_ERROR_2,
+ TEST_CASE_ERROR_3,
+ TEST_CASE_ERROR_4,
+ TEST_CASE_ERROR_5,
+ ]
+ )
+ def test_sobel_gradients_error(self, image_dict, arguments):
+ with self.assertRaises(ValueError):
+ sobel = SobelGradientsd(**arguments)
+ sobel(image_dict)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_spacing.py b/tests/test_spacing.py
index 244f14921c0..ba44bf76f22 100644
--- a/tests/test_spacing.py
+++ b/tests/test_spacing.py
@@ -19,7 +19,7 @@
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import affine_to_spacing
from monai.transforms import Spacing
-from monai.utils import ensure_tuple, fall_back_tuple
+from monai.utils import fall_back_tuple
from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose
TESTS = []
@@ -44,6 +44,16 @@
*device,
]
)
+ TESTS.append(
+ [
+ {"pixdim": 2.0, "padding_mode": "zeros", "dtype": float},
+ torch.arange(4).reshape((1, 2, 2)) + 1.0, # data
+ torch.eye(4),
+ {},
+ torch.tensor([[[1.0, 0.0], [0.0, 0.0]]]),
+ *device,
+ ]
+ )
TESTS.append(
[
{"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float},
@@ -220,12 +230,36 @@
*device,
]
)
+ TESTS.append( # 5D input
+ [
+ {"pixdim": 0.5, "padding_mode": "zeros", "mode": "nearest", "scale_extent": True},
+ torch.ones((1, 368, 336, 368)), # data
+ torch.tensor(
+ [
+ [0.41, 0.005, 0.008, -79.7],
+ [-0.0049, 0.592, 0.0664, -57.4],
+ [-0.0073, -0.0972, 0.404, -32.1],
+ [0.0, 0.0, 0.0, 1.0],
+ ]
+ ),
+ {},
+ torch.ones((1, 302, 403, 301)),
+ *device,
+ ]
+ )
TESTS_TORCH = []
for track_meta in (False, True):
for p in TEST_NDARRAYS_ALL:
TESTS_TORCH.append([[1.2, 1.3, 0.9], p(torch.zeros((1, 3, 4, 5))), track_meta])
+TEST_INVERSE = []
+for d in TEST_DEVICES:
+ for recompute in (False, True):
+ for align in (False, True):
+ for scale_extent in (False, True):
+ TEST_INVERSE.append([*d, recompute, align, scale_extent])
+
class TestSpacingCase(unittest.TestCase):
@parameterized.expand(TESTS)
@@ -238,7 +272,6 @@ def test_spacing(self, init_param, img, affine, data_param, expected_output, dev
sr = min(len(res.shape) - 1, 3)
if isinstance(init_param["pixdim"], float):
init_param["pixdim"] = [init_param["pixdim"]] * sr
- init_pixdim = ensure_tuple(init_param["pixdim"])
init_pixdim = init_param["pixdim"][:sr]
norm = affine_to_spacing(res.affine, sr).cpu().numpy()
assert_allclose(fall_back_tuple(init_pixdim, norm), norm, type_test=False)
@@ -259,15 +292,15 @@ def test_spacing_torch(self, pixdim, img, track_meta: bool):
self.assertNotEqual(img.shape, res.shape)
set_track_meta(True)
- @parameterized.expand(TEST_DEVICES)
- def test_inverse(self, device):
+ @parameterized.expand(TEST_INVERSE)
+ def test_inverse(self, device, recompute, align, scale_extent):
img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
affine = torch.tensor(
[[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
)
meta = {"fname": "somewhere"}
img = MetaTensor(img_t, affine=affine, meta=meta)
- tr = Spacing(pixdim=[1.1, 1.2, 0.9])
+ tr = Spacing(pixdim=[1.1, 1.2, 0.9], recompute_affine=recompute, align_corners=align, scale_extent=scale_extent)
# check that image and affine have changed
img = tr(img)
self.assertNotEqual(img.shape, img_t.shape)
@@ -280,6 +313,33 @@ def test_inverse(self, device):
l2_norm_affine = ((affine - img.affine) ** 2).sum() ** 0.5
self.assertLess(l2_norm_affine, 5e-2)
+ @parameterized.expand(TEST_INVERSE)
+ def test_inverse_mn_mx(self, device, recompute, align, scale_extent):
+ img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
+ affine = torch.tensor(
+ [[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
+ )
+ img = MetaTensor(img_t, affine=affine, meta={"fname": "somewhere"})
+ choices = [(None, None), [1.2, None], [None, 0.7], [0.7, 0.9]]
+ idx = np.random.choice(range(len(choices)), size=1)[0]
+ tr = Spacing(
+ pixdim=[1.1, 1.2, 0.9],
+ recompute_affine=recompute,
+ align_corners=align,
+ scale_extent=scale_extent,
+ min_pixdim=[0.9, None, choices[idx][0]],
+ max_pixdim=[1.1, 1.1, choices[idx][1]],
+ )
+ img_out = tr(img)
+ if isinstance(img_out, MetaTensor):
+ assert_allclose(
+ img_out.pixdim, [1.0, 1.125, 0.888889] if recompute else [1.0, 1.2, 0.9], type_test=False, rtol=1e-4
+ )
+ img_out = tr.inverse(img_out)
+ self.assertEqual(img_out.applied_operations, [])
+ self.assertEqual(img_out.shape, img_t.shape)
+ self.assertLess(((affine - img_out.affine) ** 2).sum() ** 0.5, 5e-2)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_spacingd.py b/tests/test_spacingd.py
index d3c7bbc6295..22729fd1b27 100644
--- a/tests/test_spacingd.py
+++ b/tests/test_spacingd.py
@@ -81,7 +81,6 @@
)
)
-
TESTS_TORCH = []
for track_meta in (False, True):
for device in TEST_DEVICES:
diff --git a/tests/test_spatial_resample.py b/tests/test_spatial_resample.py
index 63260373d0e..30bf33149bc 100644
--- a/tests/test_spatial_resample.py
+++ b/tests/test_spatial_resample.py
@@ -10,21 +10,21 @@
# limitations under the License.
import unittest
+from copy import deepcopy
import numpy as np
import torch
from parameterized import parameterized
-from monai.config import USE_COMPILED
from monai.data.meta_obj import set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
from monai.transforms import SpatialResample
+from monai.utils import optional_import
from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose
TESTS = []
-
destinations_3d = [
torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),
torch.tensor([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),
@@ -37,7 +37,7 @@
for dst, expct in zip(destinations_3d, expected_3d):
for device in TEST_DEVICES:
for align in (False, True):
- interp = ("nearest", "bilinear", 0, 1) if align and USE_COMPILED else ("nearest", "bilinear")
+ interp = ("nearest", "bilinear")
for interp_mode in interp:
for padding_mode in ("zeros", "border", "reflection"):
TESTS.append(
@@ -54,7 +54,9 @@
expct,
]
)
-
+if optional_import("cupy")[1] and optional_import("scipy.ndimage")[1]:
+ TESTS.append(deepcopy(TESTS[-1]))
+ TESTS[-1][2].update({"align_corners": True, "mode": 1, "padding_mode": "reflect"}) # type: ignore
destinations_2d = [
torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 1.0], [0.0, 0.0, 1.0]]), # flip the second
@@ -148,7 +150,7 @@ def test_4d_5d(self, new_shape, tile, device, dtype, expected_data):
dst = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]])
dst = dst.to(dtype)
- out = SpatialResample(dtype=dtype)(img=img, dst_affine=dst)
+ out = SpatialResample(dtype=dtype, align_corners=True)(img=img, dst_affine=dst, align_corners=False)
assert_allclose(out, expected_data[None], rtol=1e-2, atol=1e-2)
assert_allclose(out.affine, dst.to(torch.float32), rtol=1e-2, atol=1e-2)
@@ -164,10 +166,12 @@ def test_ill_affine(self, device):
img.affine = ill_affine
dst_affine = torch.eye(4)
SpatialResample()(img=img, dst_affine=dst_affine)
+ if not (optional_import("scipy")[1] and optional_import("cupy")[1]):
+ return
+ with self.assertRaises(ValueError): # requires scipy
+ SpatialResample(mode=1, align_corners=True)(img=img, dst_affine=dst_affine)
with self.assertRaises(ValueError):
- img.affine = torch.eye(4)
- dst_affine = torch.eye(4) * 0.1
- SpatialResample(mode=None)(img=img, dst_affine=dst_affine)
+ SpatialResample(mode=1, align_corners=False)(img=img, dst_affine=dst_affine)
@parameterized.expand(TEST_TORCH_INPUT)
def test_input_torch(self, new_shape, tile, device, dtype, expected_data, track_meta):
diff --git a/tests/test_spatial_resampled.py b/tests/test_spatial_resampled.py
index 3772cf0ddfe..5ace0b37740 100644
--- a/tests/test_spatial_resampled.py
+++ b/tests/test_spatial_resampled.py
@@ -15,7 +15,6 @@
import torch
from parameterized import parameterized
-from monai.config import USE_COMPILED
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
from monai.transforms.spatial.dictionary import SpatialResampled
@@ -23,7 +22,6 @@
TESTS = []
-
destinations_3d = [
torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),
torch.tensor([[-1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),
@@ -37,7 +35,7 @@
for device in TEST_DEVICES:
for align in (True, False):
for dtype in (torch.float32, torch.float64):
- interp = ("nearest", "bilinear", 0, 1) if align and USE_COMPILED else ("nearest", "bilinear")
+ interp = ("nearest", "bilinear")
for interp_mode in interp:
for padding_mode in ("zeros", "border", "reflection"):
TESTS.append(
diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py
index 1e39439b86e..ee8cc043e4a 100644
--- a/tests/test_splitdimd.py
+++ b/tests/test_splitdimd.py
@@ -25,7 +25,8 @@
for p in TEST_NDARRAYS:
for keepdim in (True, False):
for update_meta in (True, False):
- TESTS.append((keepdim, p, update_meta))
+ for list_output in (True, False):
+ TESTS.append((keepdim, p, update_meta, list_output))
class TestSplitDimd(unittest.TestCase):
@@ -39,14 +40,18 @@ def setUpClass(cls):
cls.data: MetaTensor = loader(data)
@parameterized.expand(TESTS)
- def test_correct(self, keepdim, im_type, update_meta):
+ def test_correct(self, keepdim, im_type, update_meta, list_output):
data = deepcopy(self.data)
data["i"] = im_type(data["i"])
arr = data["i"]
for dim in range(arr.ndim):
- out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta)(data)
- self.assertIsInstance(out, dict)
- self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim])
+ out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta, list_output=list_output)(data)
+ if list_output:
+ self.assertIsInstance(out, list)
+ self.assertEqual(len(out), arr.shape[dim])
+ else:
+ self.assertIsInstance(out, dict)
+ self.assertEqual(len(out.keys()), len(data.keys()) + arr.shape[dim])
# if updating metadata, pick some random points and
# check same world coordinates between input and output
if update_meta:
@@ -55,14 +60,20 @@ def test_correct(self, keepdim, im_type, update_meta):
split_im_idx = idx[dim]
split_idx = deepcopy(idx)
split_idx[dim] = 0
- split_im = out[f"i_{split_im_idx}"]
+ if list_output:
+ split_im = out[split_im_idx]["i"]
+ else:
+ split_im = out[f"i_{split_im_idx}"]
if isinstance(data, MetaTensor) and isinstance(split_im, MetaTensor):
# idx[1:] to remove channel and then add 1 for 4th element
real_world = data.affine @ torch.tensor(idx[1:] + [1]).double()
real_world2 = split_im.affine @ torch.tensor(split_idx[1:] + [1]).double()
assert_allclose(real_world, real_world2)
- out = out["i_0"]
+ if list_output:
+ out = out[0]["i"]
+ else:
+ out = out["i_0"]
expected_ndim = arr.ndim if keepdim else arr.ndim - 1
self.assertEqual(out.ndim, expected_ndim)
# assert is a shallow copy
diff --git a/tests/test_squeezedim.py b/tests/test_squeezedim.py
index 8403efe8366..a2b538d58c0 100644
--- a/tests/test_squeezedim.py
+++ b/tests/test_squeezedim.py
@@ -14,8 +14,9 @@
import numpy as np
from parameterized import parameterized
+from monai.data import MetaTensor
from monai.transforms import SqueezeDim
-from tests.utils import TEST_NDARRAYS
+from tests.utils import TEST_NDARRAYS, assert_allclose
TESTS, TESTS_FAIL = [], []
for p in TEST_NDARRAYS:
@@ -34,6 +35,8 @@ def test_shape(self, input_param, test_data, expected_shape):
result = SqueezeDim(**input_param)(test_data)
self.assertTupleEqual(result.shape, expected_shape)
+ if "dim" in input_param and input_param["dim"] == 2 and isinstance(result, MetaTensor):
+ assert_allclose(result.affine.shape, [3, 3])
@parameterized.expand(TESTS_FAIL)
def test_invalid_inputs(self, exception, input_param, test_data):
@@ -41,6 +44,19 @@ def test_invalid_inputs(self, exception, input_param, test_data):
with self.assertRaises(exception):
SqueezeDim(**input_param)(test_data)
+ def test_affine_ill_inputs(self):
+ img = MetaTensor(
+ np.random.rand(1, 2, 1, 3),
+ affine=[
+ [-0.7422, 0.0, 0.0, 186.3210],
+ [0.0, 0.0, -3.0, 70.6580],
+ [0.0, -0.7422, 0.0, 189.4130],
+ [0.0, 0.0, 0.0, 1.0],
+ ],
+ )
+ with self.assertWarns(UserWarning):
+ SqueezeDim(dim=2)(img)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_squeezedimd.py b/tests/test_squeezedimd.py
index 6baf4696a5d..5908e7673f3 100644
--- a/tests/test_squeezedimd.py
+++ b/tests/test_squeezedimd.py
@@ -14,8 +14,9 @@
import numpy as np
from parameterized import parameterized
+from monai.data import MetaTensor
from monai.transforms import SqueezeDimd
-from tests.utils import TEST_NDARRAYS
+from tests.utils import TEST_NDARRAYS, assert_allclose
TESTS, TESTS_FAIL = [], []
for p in TEST_NDARRAYS:
@@ -82,6 +83,8 @@ def test_shape(self, input_param, test_data, expected_shape):
result = SqueezeDimd(**input_param)(test_data)
self.assertTupleEqual(result["img"].shape, expected_shape)
self.assertTupleEqual(result["seg"].shape, expected_shape)
+ if "dim" in input_param and isinstance(result["img"], MetaTensor) and input_param["dim"] == 2:
+ assert_allclose(result["img"].affine.shape, [3, 3])
@parameterized.expand(TESTS_FAIL)
def test_invalid_inputs(self, exception, input_param, test_data):
diff --git a/tests/test_std_shift_intensity.py b/tests/test_std_shift_intensity.py
index a5549bf187d..b8306aa09c6 100644
--- a/tests/test_std_shift_intensity.py
+++ b/tests/test_std_shift_intensity.py
@@ -12,11 +12,10 @@
import unittest
import numpy as np
-import torch
from monai.transforms import ShiftIntensity, StdShiftIntensity
from monai.utils import dtype_numpy_to_torch
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
class TestStdShiftIntensity(NumpyImageTestCase2D):
@@ -29,7 +28,7 @@ def test_value(self):
expected = shifter(imt)
std_shifter = StdShiftIntensity(factor=factor)
result = std_shifter(imt)
- torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5)
+ assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)
def test_zerostd(self):
for p in TEST_NDARRAYS:
@@ -39,7 +38,7 @@ def test_zerostd(self):
factor = np.random.rand()
std_shifter = StdShiftIntensity(factor=factor, nonzero=nonzero, channel_wise=channel_wise)
result = std_shifter(image)
- torch.testing.assert_allclose(result, image, atol=0, rtol=1e-5)
+ assert_allclose(result, image, atol=0, rtol=1e-5, type_test=False)
def test_nonzero(self):
for p in TEST_NDARRAYS:
@@ -48,7 +47,7 @@ def test_nonzero(self):
std_shifter = StdShiftIntensity(factor=factor, nonzero=True)
result = std_shifter(image)
expected = p(np.asarray([[4 + factor, 0, 2 + factor], [0, 2 + factor, 4 + factor]], dtype=np.float32))
- torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5)
+ assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)
def test_channel_wise(self):
for p in TEST_NDARRAYS:
@@ -59,7 +58,7 @@ def test_channel_wise(self):
expected = p(
np.stack((np.asarray([1 + 0.5 * factor, 2 + 0.5 * factor]), np.asarray([1, 1]))).astype(np.float32)
)
- torch.testing.assert_allclose(result, expected, atol=0, rtol=1e-5)
+ assert_allclose(result, expected, atol=0, rtol=1e-5, type_test=False)
def test_dtype(self):
trans_dtype = np.float32
diff --git a/tests/test_str2bool.py b/tests/test_str2bool.py
new file mode 100644
index 00000000000..e1d9ca1ee38
--- /dev/null
+++ b/tests/test_str2bool.py
@@ -0,0 +1,31 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from monai.utils.misc import str2bool
+
+
+class TestStr2Bool(unittest.TestCase):
+ def test_str_2_bool(self):
+ for i in ("yes", "true", "t", "y", "1", True):
+ self.assertTrue(str2bool(i))
+ for i in ("no", "false", "f", "n", "0", False):
+ self.assertFalse(str2bool(i))
+ for bad_value in ("test", 0, 1, 2, None):
+ self.assertFalse(str2bool(bad_value, default=False, raise_exc=False))
+ self.assertTrue(str2bool(bad_value, default=True, raise_exc=False))
+ with self.assertRaises(ValueError):
+ self.assertTrue(str2bool(bad_value))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_str2list.py b/tests/test_str2list.py
new file mode 100644
index 00000000000..95a4dcaef01
--- /dev/null
+++ b/tests/test_str2list.py
@@ -0,0 +1,30 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from monai.utils.misc import str2list
+
+
+class TestStr2List(unittest.TestCase):
+ def test_str_2_list(self):
+ for i in ("1,2,3", "1, 2, 3", "1,2e-0,3.0", [1, 2, 3]):
+ self.assertEqual(str2list(i), [1, 2, 3])
+ for i in ("1,2,3", "1,2,3,4.3", [1, 2, 3, 4.001]):
+ self.assertNotEqual(str2list(i), [1, 2, 3, 4])
+ for bad_value in ((1, 3), int):
+ self.assertIsNone(str2list(bad_value, raise_exc=False))
+ with self.assertRaises(ValueError):
+ self.assertIsNone(str2list(bad_value))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py
index 3e5370473cb..bd46aecb970 100644
--- a/tests/test_subpixel_upsample.py
+++ b/tests/test_subpixel_upsample.py
@@ -57,7 +57,6 @@
TEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_3D_EXTRA)
TEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA)
-
# add every test back with the pad/pool sequential component omitted
for tests in list(TEST_CASE_SUBPIXEL):
args: dict = tests[0] # type: ignore
diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py
index d5aa1af688c..93a569186d7 100644
--- a/tests/test_testtimeaugmentation.py
+++ b/tests/test_testtimeaugmentation.py
@@ -92,7 +92,6 @@ def test_test_time_augmentation(self):
scale_range=((0.8, 1), (0.8, 1)),
padding_mode="zeros",
mode=("bilinear", "nearest"),
- as_tensor_output=False,
),
CropForegroundd(keys, source_key="image"),
DivisiblePadd(keys, 4),
diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py
index 013c20f4ce2..87da22eab31 100644
--- a/tests/test_thread_buffer.py
+++ b/tests/test_thread_buffer.py
@@ -13,9 +13,12 @@
import time
import unittest
+import torch
+
from monai.data import DataLoader, Dataset, ThreadBuffer, ThreadDataLoader
from monai.transforms import Compose, SimulateDelayd
-from monai.utils import PerfContext
+from monai.utils import PerfContext, set_determinism
+from tests.utils import assert_allclose
class TestDataLoader(unittest.TestCase):
@@ -53,6 +56,19 @@ def test_dataloader(self):
self.assertEqual(d["label"][0], "spleen_label_19.nii.gz")
self.assertEqual(d["label"][1], "spleen_label_31.nii.gz")
+ def test_deterministic(self):
+ set_determinism(0)
+ res_1 = list(ThreadDataLoader(torch.arange(5), batch_size=2, buffer_size=2, shuffle=True, num_workers=0))
+
+ set_determinism(0)
+ num_workers = 2 if sys.platform == "linux" else 1
+ res_2 = list(
+ ThreadDataLoader(torch.arange(5), batch_size=2, buffer_size=3, shuffle=True, num_workers=num_workers)
+ )
+
+ set_determinism(None)
+ assert_allclose(torch.cat(res_1), torch.cat(res_2), type_test=False)
+
def test_time(self):
dataset = Dataset(data=self.datalist * 2, transform=self.transform) # contains data for 2 batches
dataloader = DataLoader(dataset=dataset, batch_size=2, num_workers=0)
diff --git a/tests/test_to_device.py b/tests/test_to_device.py
index 70f1ea8828b..9f781193266 100644
--- a/tests/test_to_device.py
+++ b/tests/test_to_device.py
@@ -15,7 +15,7 @@
from parameterized import parameterized
from monai.transforms import ToDevice
-from tests.utils import skip_if_no_cuda
+from tests.utils import assert_allclose, skip_if_no_cuda
TEST_CASE_1 = ["cuda:0"]
@@ -33,7 +33,7 @@ def test_value(self, device):
converter = ToDevice(device=device, non_blocking=True)
data = torch.tensor([1, 2, 3, 4])
ret = converter(data)
- torch.testing.assert_allclose(ret, data.to(device))
+ assert_allclose(ret, data.to(device))
if __name__ == "__main__":
diff --git a/tests/test_to_deviced.py b/tests/test_to_deviced.py
index 7d075ad365b..b3ee4905666 100644
--- a/tests/test_to_deviced.py
+++ b/tests/test_to_deviced.py
@@ -15,7 +15,7 @@
from monai.data import CacheDataset, ThreadDataLoader
from monai.transforms import ToDeviced
-from tests.utils import skip_if_no_cuda
+from tests.utils import assert_allclose, skip_if_no_cuda
@skip_if_no_cuda
@@ -28,7 +28,7 @@ def test_value(self):
)
dataloader = ThreadDataLoader(dataset=dataset, num_workers=0, batch_size=1)
for i, d in enumerate(dataloader):
- torch.testing.assert_allclose(d["img"], torch.tensor([i], device=device))
+ assert_allclose(d["img"], torch.tensor([i], device=device))
if __name__ == "__main__":
diff --git a/tests/test_to_tensord.py b/tests/test_to_tensord.py
new file mode 100644
index 00000000000..4c1f2172ae1
--- /dev/null
+++ b/tests/test_to_tensord.py
@@ -0,0 +1,69 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.transforms import ToTensord
+from tests.utils import HAS_CUPY, TEST_NDARRAYS, assert_allclose, optional_import
+
+cp, _ = optional_import("cupy")
+
+im = [[1, 2], [3, 4]]
+
+TESTS = [(im, (2, 2))]
+for p in TEST_NDARRAYS:
+ TESTS.append((p(im), (2, 2)))
+
+TESTS_SINGLE = [[5]]
+for p in TEST_NDARRAYS:
+ TESTS_SINGLE.append([p(5)])
+
+
+class TestToTensord(unittest.TestCase):
+ @parameterized.expand(TESTS)
+ def test_array_input(self, test_data, expected_shape):
+ test_data = {"img": test_data}
+ to_tensord = ToTensord(keys="img", dtype=torch.float32, device="cpu", wrap_sequence=True)
+ result = to_tensord(test_data)
+ out_img = result["img"]
+ self.assertTrue(isinstance(out_img, torch.Tensor))
+ assert_allclose(out_img, test_data["img"], type_test=False)
+ self.assertTupleEqual(out_img.shape, expected_shape)
+
+ # test inverse
+ inv_data = to_tensord.inverse(result)
+ self.assertTrue(isinstance(inv_data["img"], np.ndarray))
+ assert_allclose(test_data["img"], inv_data["img"], type_test=False)
+
+ @parameterized.expand(TESTS_SINGLE)
+ def test_single_input(self, test_data):
+ test_data = {"img": test_data}
+ result = ToTensord(keys="img", track_meta=True)(test_data)
+ out_img = result["img"]
+ self.assertTrue(isinstance(out_img, torch.Tensor))
+ assert_allclose(out_img, test_data["img"], type_test=False)
+ self.assertEqual(out_img.ndim, 0)
+
+ @unittest.skipUnless(HAS_CUPY, "CuPy is required.")
+ def test_cupy(self):
+ test_data = [[1, 2], [3, 4]]
+ cupy_array = cp.ascontiguousarray(cp.asarray(test_data))
+ result = ToTensord(keys="img")({"img": cupy_array})
+ self.assertTrue(isinstance(result["img"], torch.Tensor))
+ assert_allclose(result["img"], test_data, type_test=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py
index 98b300eeace..d7341bc71e2 100644
--- a/tests/test_torchvision_fc_model.py
+++ b/tests/test_torchvision_fc_model.py
@@ -16,10 +16,13 @@
from parameterized import parameterized
from monai.networks import eval_mode
-from monai.networks.nets import TorchVisionFCModel
-from monai.utils import optional_import
+from monai.networks.nets import TorchVisionFCModel, UNet
+from monai.networks.utils import look_up_named_module, set_named_module
+from monai.utils import min_version, optional_import
-_, has_tv = optional_import("torchvision")
+Inception_V3_Weights, has_enum = optional_import("torchvision.models.inception", name="Inception_V3_Weights")
+
+_, has_tv = optional_import("torchvision", "0.12", min_version)
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -71,6 +74,25 @@
(2, 5),
]
+TEST_CASE_7 = [
+ {
+ "model_name": "inception_v3",
+ "num_classes": 5,
+ "use_conv": True,
+ "pool": "",
+ "in_channels": 2048,
+ "node_name": "Mixed_7c.cat_2",
+ },
+ (2, 3, 299, 299),
+ (2, 5, 8, 8),
+]
+
+TEST_CASE_8 = [
+ {"model_name": "vit_b_16", "num_classes": 5, "in_channels": 768, "pool": None, "fc_name": "heads.head"},
+ (2, 3, 224, 224),
+ (2, 5),
+]
+
TEST_CASE_PRETRAINED_0 = [
{"model_name": "resnet18", "num_classes": 1, "use_conv": True, "pretrained": True},
(2, 3, 224, 224),
@@ -113,9 +135,25 @@
-0.010419349186122417,
]
+TEST_CASE_PRETRAINED_6 = [
+ {
+ "model_name": "inception_v3",
+ "num_classes": 5,
+ "use_conv": False,
+ "pool": None,
+ "weights": Inception_V3_Weights.IMAGENET1K_V1 if has_enum else None,
+ },
+ (2, 3, 299, 299),
+ (2, 5),
+ -0.21029122173786163,
+]
+
class TestTorchVisionFCModel(unittest.TestCase):
- @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
+ @parameterized.expand(
+ [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]
+ + ([TEST_CASE_8] if has_enum else [])
+ )
@skipUnless(has_tv, "Requires TorchVision.")
def test_without_pretrained(self, input_param, input_shape, expected_shape):
net = TorchVisionFCModel(**input_param).to(device)
@@ -132,16 +170,28 @@ def test_without_pretrained(self, input_param, input_shape, expected_shape):
TEST_CASE_PRETRAINED_4,
TEST_CASE_PRETRAINED_5,
]
+ + ([TEST_CASE_PRETRAINED_6] if has_enum else [])
)
@skipUnless(has_tv, "Requires TorchVision.")
def test_with_pretrained(self, input_param, input_shape, expected_shape, expected_value):
net = TorchVisionFCModel(**input_param).to(device)
with eval_mode(net):
result = net.forward(torch.randn(input_shape).to(device))
- value = next(net.parameters())[0, 0, 0, 0].item()
+ value = next(net.features.parameters())[0, 0, 0, 0].item()
self.assertEqual(value, expected_value)
self.assertEqual(result.shape, expected_shape)
+class TestLookup(unittest.TestCase):
+ def test_get_module(self):
+ net = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(4, 8, 16, 32, 64), strides=(2, 2, 2, 2))
+ self.assertEqual(look_up_named_module("", net), net)
+ mod = look_up_named_module("model.1.submodule.1.submodule.1.submodule.0.conv", net)
+ self.assertTrue(str(mod).startswith("Conv2d"))
+ self.assertIsInstance(set_named_module(net, "model", torch.nn.Identity()).model, torch.nn.Identity)
+ self.assertEqual(look_up_named_module("model.1.submodule.1.submodule.1.submodule.conv", net), None)
+ self.assertEqual(look_up_named_module("test attribute", net), None)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_torchvisiond.py b/tests/test_torchvisiond.py
index def26fa26bb..b72c6f86f12 100644
--- a/tests/test_torchvisiond.py
+++ b/tests/test_torchvisiond.py
@@ -16,6 +16,7 @@
from monai.transforms import TorchVisiond
from monai.utils import set_determinism
+from tests.utils import assert_allclose
TEST_CASE_1 = [
{"keys": "img", "name": "ColorJitter"},
@@ -53,7 +54,7 @@ class TestTorchVisiond(unittest.TestCase):
def test_value(self, input_param, input_data, expected_value):
set_determinism(seed=0)
result = TorchVisiond(**input_param)(input_data)
- torch.testing.assert_allclose(result["img"], expected_value)
+ assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4, type_test=False)
if __name__ == "__main__":
diff --git a/tests/test_transform.py b/tests/test_transform.py
new file mode 100644
index 00000000000..a6c50011476
--- /dev/null
+++ b/tests/test_transform.py
@@ -0,0 +1,57 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import monai.transforms as mt
+from monai.data import Dataset
+
+
+class FaultyTransform(mt.Transform):
+ def __call__(self, _):
+ raise RuntimeError
+
+
+def faulty_lambda(_):
+ raise RuntimeError
+
+
+class TestTransform(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ super(__class__, cls).setUpClass()
+ cls.orig_value = os.environ.get("MONAI_DEBUG")
+
+ @classmethod
+ def tearDownClass(cls):
+ if cls.orig_value is not None:
+ os.environ["MONAI_DEBUG"] = cls.orig_value
+ else:
+ os.environ.pop("MONAI_DEBUG")
+ super(__class__, cls).tearDownClass()
+
+ def test_raise(self):
+ for transform in (FaultyTransform(), mt.Lambda(faulty_lambda)):
+ ds = Dataset([None] * 10, transform)
+ for debug in ("False", "True"):
+ os.environ["MONAI_DEBUG"] = debug
+ try:
+ ds[0]
+ except RuntimeError as re:
+ if debug == "False":
+ self.assertTrue("applying transform" in str(re))
+ else:
+ self.assertFalse("applying transform" in str(re))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_unetr_block.py b/tests/test_unetr_block.py
index c0f14c829da..8a4ee3a1635 100644
--- a/tests/test_unetr_block.py
+++ b/tests/test_unetr_block.py
@@ -67,7 +67,6 @@
]
TEST_UP_BLOCK.append(test_case)
-
TEST_PRUP_BLOCK = []
in_channels, out_channels = 4, 2
for spatial_dims in range(1, 4):
diff --git a/tests/test_upsample_block.py b/tests/test_upsample_block.py
index 6e2e548042b..535ad80c119 100644
--- a/tests/test_upsample_block.py
+++ b/tests/test_upsample_block.py
@@ -19,46 +19,46 @@
from monai.utils import UpsampleMode
TEST_CASES = [
- [{"dimensions": 2, "in_channels": 4}, (7, 4, 32, 48), (7, 4, 64, 96)], # 4-channel 2D, batch 7
- [{"dimensions": 1, "in_channels": 4, "out_channels": 3}, (16, 4, 63), (16, 3, 126)], # 4-channel 1D, batch 16
+ [{"spatial_dims": 2, "in_channels": 4}, (7, 4, 32, 48), (7, 4, 64, 96)], # 4-channel 2D, batch 7
+ [{"spatial_dims": 1, "in_channels": 4, "out_channels": 3}, (16, 4, 63), (16, 3, 126)], # 4-channel 1D, batch 16
[
- {"dimensions": 1, "in_channels": 4, "out_channels": 8, "mode": "deconv", "align_corners": False},
+ {"spatial_dims": 1, "in_channels": 4, "out_channels": 8, "mode": "deconv", "align_corners": False},
(16, 4, 20),
(16, 8, 40),
], # 4-channel 1D, batch 16
[
- {"dimensions": 3, "in_channels": 4, "mode": "nontrainable"},
+ {"spatial_dims": 3, "in_channels": 4, "mode": "nontrainable"},
(16, 4, 32, 24, 48),
(16, 4, 64, 48, 96),
], # 4-channel 3D, batch 16
[
- {"dimensions": 3, "in_channels": 4, "mode": "nontrainable", "size": 64},
+ {"spatial_dims": 3, "in_channels": 4, "mode": "nontrainable", "size": 64},
(16, 4, 32, 24, 48),
(16, 4, 64, 64, 64),
], # 4-channel 3D, batch 16
[
- {"dimensions": 3, "in_channels": 4, "mode": "nontrainable", "size": (64, 24, 48)},
+ {"spatial_dims": 3, "in_channels": 4, "mode": "nontrainable", "size": (64, 24, 48)},
(16, 4, 32, 24, 48),
(16, 4, 64, 24, 48),
], # 4-channel 3D, batch 16
[
- {"dimensions": 3, "in_channels": 1, "mode": "deconv", "scale_factor": 3, "align_corners": False},
+ {"spatial_dims": 3, "in_channels": 1, "mode": "deconv", "scale_factor": 3, "align_corners": False},
(16, 1, 10, 15, 20),
(16, 1, 30, 45, 60),
], # 1-channel 3D, batch 16
[
- {"dimensions": 3, "in_channels": 1, "mode": "pixelshuffle", "scale_factor": 2, "align_corners": False},
+ {"spatial_dims": 3, "in_channels": 1, "mode": "pixelshuffle", "scale_factor": 2, "align_corners": False},
(16, 1, 10, 15, 20),
(16, 1, 20, 30, 40),
], # 1-channel 3D, batch 16
[
- {"dimensions": 2, "in_channels": 4, "mode": "pixelshuffle", "scale_factor": 2},
+ {"spatial_dims": 2, "in_channels": 4, "mode": "pixelshuffle", "scale_factor": 2},
(16, 4, 10, 15),
(16, 4, 20, 30),
], # 4-channel 2D, batch 16
[
{
- "dimensions": 3,
+ "spatial_dims": 3,
"mode": "pixelshuffle",
"scale_factor": 2,
"align_corners": False,
@@ -67,6 +67,16 @@
(16, 1, 10, 15, 20),
(16, 3, 20, 30, 40),
], # 1-channel 3D, batch 16, pre_conv
+ [
+ {"spatial_dims": 3, "in_channels": 8, "out_channels": 4, "mode": "deconvgroup"},
+ (16, 8, 16, 16, 16),
+ (16, 4, 32, 32, 32),
+ ], # 8-channel 3D, batch 16
+ [
+ {"spatial_dims": 2, "in_channels": 32, "out_channels": 16, "mode": "deconvgroup", "scale_factor": 2},
+ (8, 32, 16, 16),
+ (8, 16, 32, 32),
+ ], # 32-channel 2D, batch 8
]
TEST_CASES_EQ = []
@@ -74,20 +84,47 @@
expected_shape = (16, 5, 4 * s, 5 * s, 6 * s)
for t in UpsampleMode:
test_case = [
- {"dimensions": 3, "in_channels": 3, "out_channels": 5, "mode": t, "scale_factor": s, "align_corners": True},
+ {
+ "spatial_dims": 3,
+ "in_channels": 3,
+ "out_channels": 5,
+ "mode": t,
+ "scale_factor": s,
+ "align_corners": True,
+ },
(16, 3, 4, 5, 6),
expected_shape,
]
TEST_CASES_EQ.append(test_case)
+TEST_CASES_EQ2 = [] # type: ignore
+for s in range(2, 5):
+ for k in range(1, 7):
+ expected_shape = (16, 5, 4 * s, 5 * s, 6 * s)
+ for t in UpsampleMode:
+ test_case = [
+ {
+ "spatial_dims": 3,
+ "in_channels": 3,
+ "out_channels": 5,
+ "mode": t,
+ "scale_factor": s,
+ "kernel_size": k,
+ "align_corners": False,
+ },
+ (16, 3, 4, 5, 6),
+ expected_shape,
+ ]
+ TEST_CASES_EQ.append(test_case)
+
class TestUpsample(unittest.TestCase):
- @parameterized.expand(TEST_CASES + TEST_CASES_EQ)
+ @parameterized.expand(TEST_CASES + TEST_CASES_EQ + TEST_CASES_EQ2)
def test_shape(self, input_param, input_shape, expected_shape):
net = UpSample(**input_param)
with eval_mode(net):
result = net(torch.randn(input_shape))
- self.assertEqual(result.shape, expected_shape)
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
if __name__ == "__main__":
diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py
index fa4d10b402f..7041a09f52e 100644
--- a/tests/test_utils_pytorch_numpy_unification.py
+++ b/tests/test_utils_pytorch_numpy_unification.py
@@ -12,11 +12,12 @@
import unittest
import numpy as np
+import torch
from parameterized import parameterized
from monai.transforms.utils_pytorch_numpy_unification import mode, percentile
from monai.utils import set_determinism
-from tests.utils import TEST_NDARRAYS, assert_allclose
+from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick
TEST_MODE = []
for p in TEST_NDARRAYS:
@@ -39,6 +40,17 @@ def test_percentile(self):
results.append(percentile(arr, q))
assert_allclose(results[0], results[-1], type_test=False, atol=1e-4, rtol=1e-4)
+ @skip_if_quick
+ def test_many_elements_quantile(self): # pytorch#64947
+ for p in TEST_NDARRAYS:
+ for elements in (1000, 17_000_000):
+ for t in [*TEST_NDARRAYS, list]:
+ x = p(np.random.randn(elements))
+ q = percentile(x, t([10, 50]))
+ if isinstance(x, torch.Tensor):
+ self.assertIsInstance(q, torch.Tensor)
+ assert_allclose(q.shape, [2], type_test=False)
+
def test_fails(self):
for p in TEST_NDARRAYS:
for q in (-1, 101):
diff --git a/tests/test_varautoencoder.py b/tests/test_varautoencoder.py
index 04fc07f53fc..a6315ebc63a 100644
--- a/tests/test_varautoencoder.py
+++ b/tests/test_varautoencoder.py
@@ -75,7 +75,34 @@
(1, 3, 128, 128, 128),
]
-CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]
+TEST_CASE_4 = [ # 4-channel 1D, batch 4
+ {
+ "spatial_dims": 1,
+ "in_shape": (4, 128),
+ "out_channels": 3,
+ "latent_size": 2,
+ "channels": (4, 8, 16),
+ "strides": (2, 2, 2),
+ },
+ (1, 4, 128),
+ (1, 3, 128),
+]
+
+TEST_CASE_5 = [ # 4-channel 1D, batch 4, use_sigmoid = False
+ {
+ "spatial_dims": 1,
+ "in_shape": (4, 128),
+ "out_channels": 3,
+ "latent_size": 2,
+ "channels": (4, 8, 16),
+ "strides": (2, 2, 2),
+ "use_sigmoid": False,
+ },
+ (1, 4, 128),
+ (1, 3, 128),
+]
+
+CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]
class TestVarAutoEncoder(unittest.TestCase):
diff --git a/tests/test_varnet.py b/tests/test_varnet.py
new file mode 100644
index 00000000000..c715e7d37fb
--- /dev/null
+++ b/tests/test_varnet.py
@@ -0,0 +1,61 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.apps.reconstruction.networks.nets.coil_sensitivity_model import CoilSensitivityModel
+from monai.apps.reconstruction.networks.nets.complex_unet import ComplexUnet
+from monai.apps.reconstruction.networks.nets.varnet import VariationalNetworkModel
+from monai.networks import eval_mode
+from tests.utils import SkipIfBeforePyTorchVersion, test_script_save
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+coil_sens_model = CoilSensitivityModel(spatial_dims=2, features=[8, 16, 32, 64, 128, 8])
+refinement_model = ComplexUnet(spatial_dims=2, features=[8, 16, 32, 64, 128, 8])
+num_cascades = 2
+TESTS = []
+TESTS.append([coil_sens_model, refinement_model, num_cascades, (1, 3, 50, 50, 2), (1, 50, 50)]) # batch=1
+TESTS.append([coil_sens_model, refinement_model, num_cascades, (2, 3, 50, 50, 2), (2, 50, 50)]) # batch=2
+
+
+class TestVarNet(unittest.TestCase):
+ @parameterized.expand(TESTS)
+ def test_shape(self, coil_sens_model, refinement_model, num_cascades, input_shape, expected_shape):
+ net = VariationalNetworkModel(coil_sens_model, refinement_model, num_cascades).to(device)
+ mask_shape = [1 for _ in input_shape]
+ mask_shape[-2] = input_shape[-2]
+ mask = torch.zeros(mask_shape)
+ mask[..., mask_shape[-2] // 2 - 5 : mask_shape[-2] // 2 + 5, :] = 1
+
+ with eval_mode(net):
+ result = net(torch.randn(input_shape).to(device), mask.bool().to(device))
+ self.assertEqual(result.shape, expected_shape)
+
+ @parameterized.expand(TESTS)
+ @SkipIfBeforePyTorchVersion((1, 9, 1))
+ def test_script(self, coil_sens_model, refinement_model, num_cascades, input_shape, expected_shape):
+ net = VariationalNetworkModel(coil_sens_model, refinement_model, num_cascades)
+
+ mask_shape = [1 for _ in input_shape]
+ mask_shape[-2] = input_shape[-2]
+ mask = torch.zeros(mask_shape)
+ mask[..., mask_shape[-2] // 2 - 5 : mask_shape[-2] // 2 + 5, :] = 1
+
+ test_data = torch.randn(input_shape)
+
+ test_script_save(net, test_data, mask.bool())
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_version_leq.py b/tests/test_version_leq.py
index 86fccca9fb2..725c1ee128d 100644
--- a/tests/test_version_leq.py
+++ b/tests/test_version_leq.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import unittest
diff --git a/tests/test_video_datasets.py b/tests/test_video_datasets.py
new file mode 100644
index 00000000000..78e015e350e
--- /dev/null
+++ b/tests/test_video_datasets.py
@@ -0,0 +1,144 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+from typing import Type, Union
+
+import torch
+
+import monai.transforms as mt
+from monai.data.dataloader import DataLoader
+from monai.data.video_dataset import CameraDataset, VideoDataset, VideoFileDataset
+from monai.utils.module import optional_import
+from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config
+
+cv2, has_cv2 = optional_import("cv2")
+
+NUM_CAPTURE_DEVICES = CameraDataset.get_num_devices()
+TRANSFORMS = mt.Compose(
+ [mt.EnsureChannelFirst(True, "no_channel"), mt.DivisiblePad(16), mt.ScaleIntensity(), mt.CastToType(torch.float32)]
+)
+
+
+class Base:
+ class TestVideoDataset(unittest.TestCase):
+ video_source: Union[int, str]
+ ds: Type[VideoDataset]
+
+ def get_video_source(self):
+ return self.video_source
+
+ def get_ds(self, *args, **kwargs) -> VideoDataset:
+ return self.ds(video_source=self.get_video_source(), transform=TRANSFORMS, *args, **kwargs) # type: ignore
+
+ @unittest.skipIf(has_cv2, "Only tested when OpenCV not installed.")
+ def test_no_opencv_raises(self):
+ with self.assertRaises(RuntimeError):
+ _ = self.get_ds(max_num_frames=10)
+
+ @unittest.skipUnless(has_cv2, "OpenCV required.")
+ def test_multiprocessing(self):
+ for num_workers in (0, 2):
+ multiprocessing = num_workers > 0
+ ds = self.get_ds(max_num_frames=100, multiprocessing=multiprocessing)
+ dl = DataLoader(ds, num_workers=num_workers, batch_size=2)
+ _ = next(iter(dl))
+
+ @unittest.skipUnless(has_cv2, "OpenCV required.")
+ def test_multiple_sources(self, should_match: bool = True):
+ ds1 = self.get_ds()
+ ds2 = self.get_ds()
+ if should_match:
+ assert_allclose(ds1.get_frame(), ds2.get_frame())
+
+ @unittest.skipUnless(has_cv2, "OpenCV required.")
+ def test_dataset(self, known_num_frames=None, known_fps=None):
+ num_frames = (10,) if known_num_frames is None else (10, None)
+ for max_num_frames in num_frames:
+ ds = self.get_ds(max_num_frames=max_num_frames)
+ if known_fps is not None:
+ self.assertEqual(ds.get_fps(), known_fps)
+ frames = list(ds)
+ if max_num_frames is not None:
+ self.assertEqual(len(frames), max_num_frames)
+ elif known_num_frames is not None:
+ self.assertEqual(len(frames), len(ds))
+ for f in frames:
+ self.assertTupleEqual(f.shape, frames[0].shape)
+
+
+@unittest.skipIf(NUM_CAPTURE_DEVICES == 0, "At least one capture device required.")
+class TestCameraDataset(Base.TestVideoDataset):
+ video_source = 0
+ ds = CameraDataset
+
+ @unittest.skipUnless(has_cv2, "OpenCV required.")
+ def test_multiple_sources(self):
+ super().test_multiple_sources(should_match=False)
+
+ @unittest.skipUnless(has_cv2, "OpenCV required.")
+ def test_device_out_of_range(self):
+ capture_device = NUM_CAPTURE_DEVICES + 1
+ with self.assertRaises(RuntimeError):
+ _ = CameraDataset(capture_device, TRANSFORMS, 0)
+
+
+class TestVideoFileDataset(Base.TestVideoDataset):
+ ds = VideoFileDataset
+
+ @classmethod
+ def setUpClass(cls):
+ super(__class__, cls).setUpClass()
+ codecs = VideoFileDataset.get_available_codecs()
+ if ".mp4" in codecs.values():
+ fname = "endo.mp4"
+ config = testing_data_config("videos", "endovis")
+ cls.known_fps = 2.0
+ cls.known_num_frames = 23
+ elif ".avi" in codecs.values():
+ fname = "ultrasound.avi"
+ config = testing_data_config("videos", "ultrasound")
+ cls.known_fps = 2.0
+ cls.known_num_frames = 523
+ else:
+ cls.known_fps = None
+ cls.known_num_frames = None
+ cls.video_source = None
+ return
+ cls.video_source = os.path.join(os.path.dirname(__file__), "testing_data", fname)
+ download_url_or_skip_test(
+ url=config["url"],
+ filepath=cls.video_source,
+ hash_val=config.get("hash_val"),
+ hash_type=config.get("hash_type", "sha256"),
+ )
+
+ @unittest.skipUnless(has_cv2, "OpenCV required.")
+ def test_dataset(self):
+ super().test_dataset(self.known_num_frames, self.known_fps)
+ self.assertEqual(self.get_ds().get_num_frames(), self.known_num_frames)
+
+ def test_available_codecs(self):
+ codecs = VideoFileDataset.get_available_codecs()
+ if not has_cv2:
+ self.assertEqual(codecs, {})
+ else:
+ self.assertGreaterEqual(len(codecs), 0)
+
+ def get_video_source(self):
+ if self.video_source is None:
+ raise unittest.SkipTest("missing required codecs")
+ return super().get_video_source()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_vis_gradbased.py b/tests/test_vis_gradbased.py
index 7655ca661e7..5af87698720 100644
--- a/tests/test_vis_gradbased.py
+++ b/tests/test_vis_gradbased.py
@@ -17,32 +17,47 @@
from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
from monai.visualize import GuidedBackpropGrad, GuidedBackpropSmoothGrad, SmoothGrad, VanillaGrad
+
+class DenseNetAdjoint(DenseNet121):
+ def __call__(self, x, adjoint_info):
+ if adjoint_info != 42:
+ raise ValueError
+ return super().__call__(x)
+
+
DENSENET2D = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
DENSENET3D = DenseNet(spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,))
SENET2D = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
SENET3D = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)
+DENSENET2DADJOINT = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=3)
TESTS = []
for type in (VanillaGrad, SmoothGrad, GuidedBackpropGrad, GuidedBackpropSmoothGrad):
# 2D densenet
- TESTS.append([type, DENSENET2D, (1, 1, 48, 64), (1, 1, 48, 64)])
+ TESTS.append([type, DENSENET2D, (1, 1, 48, 64)])
# 3D densenet
- TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6), (1, 1, 6, 6, 6)])
+ TESTS.append([type, DENSENET3D, (1, 1, 6, 6, 6)])
# 2D senet
- TESTS.append([type, SENET2D, (1, 3, 64, 64), (1, 1, 64, 64)])
+ TESTS.append([type, SENET2D, (1, 3, 64, 64)])
# 3D senet
- TESTS.append([type, SENET3D, (1, 3, 8, 8, 48), (1, 1, 8, 8, 48)])
+ TESTS.append([type, SENET3D, (1, 3, 8, 8, 48)])
+ # 2D densenet - adjoint
+ TESTS.append([type, DENSENET2DADJOINT, (1, 1, 48, 64)])
class TestGradientClassActivationMap(unittest.TestCase):
@parameterized.expand(TESTS)
- def test_shape(self, vis_type, model, shape, expected_shape):
+ def test_shape(self, vis_type, model, shape):
device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+ # optionally test for adjoint info
+ kwargs = {"adjoint_info": 42} if isinstance(model, DenseNetAdjoint) else {}
+
model.to(device)
model.eval()
vis = vis_type(model)
x = torch.rand(shape, device=device)
- result = vis(x)
+ result = vis(x, **kwargs)
self.assertTupleEqual(result.shape, x.shape)
diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py
index 755f4d49ae7..d81007aa15f 100644
--- a/tests/test_vis_gradcam.py
+++ b/tests/test_vis_gradcam.py
@@ -10,80 +10,136 @@
# limitations under the License.
import unittest
+from typing import Any, List
import numpy as np
import torch
from parameterized import parameterized
from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
-from monai.visualize import GradCAM
-
-# 2D
-TEST_CASE_0 = [
- {
- "model": "densenet2d",
- "shape": (2, 1, 48, 64),
- "feature_shape": (2, 1, 1, 2),
- "target_layers": "class_layers.relu",
- },
- (2, 1, 48, 64),
-]
-# 3D
-TEST_CASE_1 = [
- {
- "model": "densenet3d",
- "shape": (2, 1, 6, 6, 6),
- "feature_shape": (2, 1, 2, 2, 2),
- "target_layers": "class_layers.relu",
- },
- (2, 1, 6, 6, 6),
-]
-# 2D
-TEST_CASE_2 = [
- {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"},
- (2, 1, 64, 64),
-]
-
-# 3D
-TEST_CASE_3 = [
- {"model": "senet3d", "shape": (2, 3, 8, 8, 48), "feature_shape": (2, 1, 1, 1, 2), "target_layers": "layer4"},
- (2, 1, 8, 8, 48),
-]
+from monai.visualize import GradCAM, GradCAMpp
+from tests.utils import assert_allclose
+
+
+class DenseNetAdjoint(DenseNet121):
+ def __call__(self, x, adjoint_info):
+ if adjoint_info != 42:
+ raise ValueError
+ return super().__call__(x)
+
+
+TESTS: List[Any] = []
+TESTS_ILL: List[Any] = []
+
+for cam in (GradCAM, GradCAMpp):
+ # 2D
+ TESTS.append(
+ [
+ cam,
+ {
+ "model": "densenet2d",
+ "shape": (2, 1, 48, 64),
+ "feature_shape": (2, 1, 1, 2),
+ "target_layers": "class_layers.relu",
+ },
+ (2, 1, 48, 64),
+ ]
+ )
+ # 3D
+ TESTS.append(
+ [
+ cam,
+ {
+ "model": "densenet3d",
+ "shape": (2, 1, 6, 6, 6),
+ "feature_shape": (2, 1, 2, 2, 2),
+ "target_layers": "class_layers.relu",
+ },
+ (2, 1, 6, 6, 6),
+ ]
+ )
+ # 2D
+ TESTS.append(
+ [
+ cam,
+ {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"},
+ (2, 1, 64, 64),
+ ]
+ )
+
+ # 3D
+ TESTS.append(
+ [
+ cam,
+ {
+ "model": "senet3d",
+ "shape": (2, 3, 8, 8, 48),
+ "feature_shape": (2, 1, 1, 1, 2),
+ "target_layers": "layer4",
+ },
+ (2, 1, 8, 8, 48),
+ ]
+ )
+
+ # adjoint info
+ TESTS.append(
+ [
+ cam,
+ {
+ "model": "adjoint",
+ "shape": (2, 1, 48, 64),
+ "feature_shape": (2, 1, 1, 2),
+ "target_layers": "class_layers.relu",
+ },
+ (2, 1, 48, 64),
+ ]
+ )
+
+ TESTS_ILL.append([cam])
class TestGradientClassActivationMap(unittest.TestCase):
- @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
- def test_shape(self, input_data, expected_shape):
+ @parameterized.expand(TESTS)
+ def test_shape(self, cam_class, input_data, expected_shape):
if input_data["model"] == "densenet2d":
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
- if input_data["model"] == "densenet3d":
+ elif input_data["model"] == "densenet3d":
model = DenseNet(
spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)
)
- if input_data["model"] == "senet2d":
+ elif input_data["model"] == "senet2d":
model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
- if input_data["model"] == "senet3d":
+ elif input_data["model"] == "senet3d":
model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)
+ elif input_data["model"] == "adjoint":
+ model = DenseNetAdjoint(spatial_dims=2, in_channels=1, out_channels=3)
+
+ # optionally test for adjoint info
+ kwargs = {"adjoint_info": 42} if input_data["model"] == "adjoint" else {}
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
- cam = GradCAM(nn_module=model, target_layers=input_data["target_layers"])
+ cam = cam_class(nn_module=model, target_layers=input_data["target_layers"])
image = torch.rand(input_data["shape"], device=device)
- result = cam(x=image, layer_idx=-1)
- np.testing.assert_array_equal(cam.nn_module.class_idx.cpu(), model(image).max(1)[-1].cpu())
- fea_shape = cam.feature_map_size(input_data["shape"], device=device)
+ inferred = model(image, **kwargs).max(1)[-1].cpu()
+ result = cam(x=image, layer_idx=-1, **kwargs)
+ np.testing.assert_array_equal(cam.nn_module.class_idx.cpu(), inferred)
+
+ fea_shape = cam.feature_map_size(input_data["shape"], device=device, **kwargs)
self.assertTupleEqual(fea_shape, input_data["feature_shape"])
self.assertTupleEqual(result.shape, expected_shape)
# check result is same whether class_idx=None is used or not
- result2 = cam(x=image, layer_idx=-1, class_idx=model(image).max(1)[-1].cpu())
- torch.testing.assert_allclose(result, result2)
+ result2 = cam(x=image, layer_idx=-1, class_idx=inferred, **kwargs)
+ assert_allclose(result, result2)
- def test_ill(self):
+ @parameterized.expand(TESTS_ILL)
+ def test_ill(self, cam_class):
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
for name, x in model.named_parameters():
if "features" in name:
x.requires_grad = False
- cam = GradCAM(nn_module=model, target_layers="class_layers.relu")
+ cam = cam_class(nn_module=model, target_layers="class_layers.relu")
image = torch.rand((2, 1, 48, 64))
with self.assertRaises(IndexError):
cam(x=image)
diff --git a/tests/test_vis_gradcampp.py b/tests/test_vis_gradcampp.py
deleted file mode 100644
index a261b6055b0..00000000000
--- a/tests/test_vis_gradcampp.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import torch
-from parameterized import parameterized
-
-from monai.networks.nets import DenseNet, DenseNet121, SEResNet50
-from monai.visualize import GradCAMpp
-
-# 2D
-TEST_CASE_0 = [
- {
- "model": "densenet2d",
- "shape": (2, 1, 48, 64),
- "feature_shape": (2, 1, 1, 2),
- "target_layers": "class_layers.relu",
- },
- (2, 1, 48, 64),
-]
-# 3D
-TEST_CASE_1 = [
- {
- "model": "densenet3d",
- "shape": (2, 1, 6, 6, 6),
- "feature_shape": (2, 1, 2, 2, 2),
- "target_layers": "class_layers.relu",
- },
- (2, 1, 6, 6, 6),
-]
-# 2D
-TEST_CASE_2 = [
- {"model": "senet2d", "shape": (2, 3, 64, 64), "feature_shape": (2, 1, 2, 2), "target_layers": "layer4"},
- (2, 1, 64, 64),
-]
-
-# 3D
-TEST_CASE_3 = [
- {"model": "senet3d", "shape": (2, 3, 8, 8, 48), "feature_shape": (2, 1, 1, 1, 2), "target_layers": "layer4"},
- (2, 1, 8, 8, 48),
-]
-
-
-class TestGradientClassActivationMapPP(unittest.TestCase):
- @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
- def test_shape(self, input_data, expected_shape):
- if input_data["model"] == "densenet2d":
- model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
- if input_data["model"] == "densenet3d":
- model = DenseNet(
- spatial_dims=3, in_channels=1, out_channels=3, init_features=2, growth_rate=2, block_config=(6,)
- )
- if input_data["model"] == "senet2d":
- model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
- if input_data["model"] == "senet3d":
- model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
- model.to(device)
- model.eval()
- cam = GradCAMpp(nn_module=model, target_layers=input_data["target_layers"])
- image = torch.rand(input_data["shape"], device=device)
- result = cam(x=image, layer_idx=-1)
- fea_shape = cam.feature_map_size(input_data["shape"], device=device)
- self.assertTupleEqual(fea_shape, input_data["feature_shape"])
- self.assertTupleEqual(result.shape, expected_shape)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_watershed.py b/tests/test_watershed.py
new file mode 100644
index 00000000000..705ddce8179
--- /dev/null
+++ b/tests/test_watershed.py
@@ -0,0 +1,58 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.array import (
+ GenerateDistanceMap,
+ GenerateInstanceBorder,
+ GenerateWatershedMarkers,
+ GenerateWatershedMask,
+ Watershed,
+)
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS
+
+_, has_skimage = optional_import("skimage", "0.19.3", min_version)
+_, has_scipy = optional_import("scipy", "1.8.1", min_version)
+
+np.random.RandomState(123)
+
+TESTS = []
+params = {"connectivity": 1}
+for p in TEST_NDARRAYS:
+ image = p(np.random.rand(1, 10, 10))
+ hover_map = p(np.random.rand(2, 10, 10))
+
+ TESTS.append([params, image, hover_map, (1, 10, 10)])
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+@unittest.skipUnless(has_scipy, "Requires scipy library.")
+class TestWatershed(unittest.TestCase):
+ @parameterized.expand(TESTS)
+ def test_output(self, args, image, hover_map, expected_shape):
+ mask = GenerateWatershedMask()(image)
+ border_map = GenerateInstanceBorder(kernel_size=3)(mask, hover_map)
+ distance_map = GenerateDistanceMap()(mask, border_map)
+ markers = GenerateWatershedMarkers()(mask, border_map)
+
+ calculate_instance_seg = Watershed(**args)
+ output = calculate_instance_seg(distance_map, mask, markers)
+
+ self.assertTupleEqual(output.shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_watershedd.py b/tests/test_watershedd.py
new file mode 100644
index 00000000000..6474759de42
--- /dev/null
+++ b/tests/test_watershedd.py
@@ -0,0 +1,68 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.apps.pathology.transforms.post.dictionary import (
+ GenerateDistanceMapd,
+ GenerateInstanceBorderd,
+ GenerateWatershedMarkersd,
+ GenerateWatershedMaskd,
+ Watershedd,
+)
+from monai.transforms import Compose
+from monai.utils import min_version, optional_import
+from tests.utils import TEST_NDARRAYS
+
+_, has_skimage = optional_import("skimage", "0.19.3", min_version)
+_, has_scipy = optional_import("scipy", "1.8.1", min_version)
+
+TESTS = []
+params = {"keys": "dist", "mask_key": "mask", "markers_key": "markers", "connectivity": 1}
+for p in TEST_NDARRAYS:
+ image = p(np.random.rand(1, 10, 10))
+ hover_map = p(np.random.rand(2, 10, 10))
+
+ TESTS.append([params, image, hover_map, (1, 10, 10)])
+
+ params.update({"markers_key": None})
+ TESTS.append([params, image, hover_map, (1, 10, 10)])
+
+ params.update({"mask_key": None, "markers_key": None})
+ TESTS.append([params, image, hover_map, (1, 10, 10)])
+
+
+@unittest.skipUnless(has_skimage, "Requires scikit-image library.")
+@unittest.skipUnless(has_scipy, "Requires scipy library.")
+class TestWatershedd(unittest.TestCase):
+ @parameterized.expand(TESTS)
+ def test_output(self, args, image, hover_map, expected_shape):
+ data = {"output": image, "hover_map": hover_map}
+
+ trans = Compose(
+ [
+ GenerateWatershedMaskd(keys="output"),
+ GenerateInstanceBorderd(keys="mask", hover_map_key="hover_map", kernel_size=3),
+ GenerateDistanceMapd(keys="mask", border_key="border"),
+ GenerateWatershedMarkersd(keys="mask", border_key="border"),
+ Watershedd(**args),
+ ]
+ )
+
+ output = trans(data)
+ self.assertTupleEqual(output["dist"].shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_wsireader.py b/tests/test_wsireader.py
index a0a076b682e..5fb113e1934 100644
--- a/tests/test_wsireader.py
+++ b/tests/test_wsireader.py
@@ -14,16 +14,16 @@
from unittest import skipUnless
import numpy as np
-import torch
from numpy.testing import assert_array_equal
from parameterized import parameterized
from monai.data import DataLoader, Dataset
-from monai.data.image_reader import WSIReader
-from monai.transforms import Compose, FromMetaTensord, LoadImaged, ToTensord
-from monai.utils import first, optional_import
-from monai.utils.enums import PostFix
-from tests.utils import download_url_or_skip_test, testing_data_config
+from monai.data.image_reader import WSIReader as WSIReaderDeprecated
+from monai.data.wsi_reader import WSIReader
+from monai.transforms import Compose, LoadImaged, ToTensord
+from monai.utils import deprecated, first, optional_import
+from monai.utils.enums import PostFix, WSIPatchKeys
+from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config
cucim, has_cucim = optional_import("cucim")
has_cucim = has_cucim and hasattr(cucim, "CuImage")
@@ -44,19 +44,19 @@
TEST_CASE_TRANSFORM_0 = [FILE_PATH, 4, (HEIGHT // 16, WIDTH // 16), (1, 3, HEIGHT // 16, WIDTH // 16)]
-TEST_CASE_1 = [
+TEST_CASE_DEP_1 = [
FILE_PATH,
{"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
np.array([[[246], [246]], [[246], [246]], [[246], [246]]]),
]
-TEST_CASE_2 = [
+TEST_CASE_DEP_2 = [
FILE_PATH,
{"location": (0, 0), "size": (2, 1), "level": 2},
np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
]
-TEST_CASE_3 = [
+TEST_CASE_DEP_3 = [
FILE_PATH,
{"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 2},
np.array(
@@ -67,18 +67,64 @@
),
]
-TEST_CASE_4 = [
+TEST_CASE_DEP_4 = [
FILE_PATH,
{"location": (0, 0), "size": (8, 8), "level": 2, "grid_shape": (2, 1), "patch_size": 1},
np.array([[[[239]], [[239]], [[239]]], [[[243]], [[243]], [[243]]]]),
]
-TEST_CASE_5 = [
+TEST_CASE_DEP_5 = [
FILE_PATH,
{"location": (HEIGHT - 2, WIDTH - 2), "level": 0, "grid_shape": (1, 1)},
np.array([[[239, 239], [239, 239]], [[239, 239], [239, 239]], [[237, 237], [237, 237]]]),
]
+TEST_CASE_1 = [
+ FILE_PATH,
+ {},
+ {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
+ np.array([[[246], [246]], [[246], [246]], [[246], [246]]]),
+]
+
+TEST_CASE_2 = [
+ FILE_PATH,
+ {},
+ {"location": (0, 0), "size": (2, 1), "level": 2},
+ np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
+]
+
+TEST_CASE_3 = [
+ FILE_PATH,
+ {"channel_dim": -1},
+ {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
+ np.moveaxis(np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), 0, -1),
+]
+
+TEST_CASE_4 = [
+ FILE_PATH,
+ {"channel_dim": 2},
+ {"location": (0, 0), "size": (2, 1), "level": 2},
+ np.moveaxis(np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), 0, -1),
+]
+
+TEST_CASE_5 = [
+ FILE_PATH,
+ {"level": 2},
+ {"location": (0, 0), "size": (2, 1)},
+ np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
+]
+
+TEST_CASE_MULTI_WSI = [
+ [FILE_PATH, FILE_PATH],
+ {"location": (0, 0), "size": (2, 1), "level": 2},
+ np.concatenate(
+ [
+ np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
+ np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
+ ],
+ axis=0,
+ ),
+]
TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW
@@ -89,6 +135,8 @@
TEST_CASE_ERROR_2C = [np.ones((16, 16, 2), dtype=np.uint8)] # two color channels
TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color
+TEST_CASE_MPP_0 = [FILE_PATH, 0, (1000.0, 1000.0)]
+
def save_rgba_tiff(array: np.ndarray, filename: str, mode: str):
"""
@@ -108,6 +156,20 @@ def save_rgba_tiff(array: np.ndarray, filename: str, mode: str):
return filename
+def save_gray_tiff(array: np.ndarray, filename: str):
+ """
+ Save numpy array into a TIFF file
+
+ Args:
+ array: numpy ndarray with any shape
+ filename: the filename to be used for the tiff file.
+ """
+ img_gray = array
+ imwrite(filename, img_gray, shape=img_gray.shape)
+
+ return filename
+
+
@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!")
def setUpModule():
hash_type = testing_data_config("images", FILE_KEY, "hash_type")
@@ -115,21 +177,22 @@ def setUpModule():
download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)
-class WSIReaderTests:
+@deprecated(since="0.8", msg_suffix="use tests for `monai.wsi_reader.WSIReader` instead, `WSIReaderTests`.")
+class WSIReaderDeprecatedTests:
class Tests(unittest.TestCase):
backend = None
@parameterized.expand([TEST_CASE_0])
def test_read_whole_image(self, file_path, level, expected_shape):
- reader = WSIReader(self.backend, level=level)
+ reader = WSIReaderDeprecated(self.backend, level=level)
with reader.read(file_path) as img_obj:
img = reader.get_data(img_obj)[0]
self.assertTupleEqual(img.shape, expected_shape)
- @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_5])
+ @parameterized.expand([TEST_CASE_DEP_1, TEST_CASE_DEP_2, TEST_CASE_DEP_5])
def test_read_region(self, file_path, patch_info, expected_img):
kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {}
- reader = WSIReader(self.backend, **kwargs)
+ reader = WSIReaderDeprecated(self.backend, **kwargs)
with reader.read(file_path, **kwargs) as img_obj:
if self.backend == "tifffile":
with self.assertRaises(ValueError):
@@ -143,9 +206,9 @@ def test_read_region(self, file_path, patch_info, expected_img):
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))
- @parameterized.expand([TEST_CASE_3, TEST_CASE_4])
+ @parameterized.expand([TEST_CASE_DEP_3, TEST_CASE_DEP_4])
def test_read_patches(self, file_path, patch_info, expected_img):
- reader = WSIReader(self.backend)
+ reader = WSIReaderDeprecated(self.backend)
with reader.read(file_path) as img_obj:
if self.backend == "tifffile":
with self.assertRaises(ValueError):
@@ -155,6 +218,111 @@ def test_read_patches(self, file_path, patch_info, expected_img):
self.assertTupleEqual(img.shape, expected_img.shape)
self.assertIsNone(assert_array_equal(img, expected_img))
+ @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1])
+ @skipUnless(has_tiff, "Requires tifffile.")
+ def test_read_rgba(self, img_expected):
+ # skip for OpenSlide since not working with images without tiles
+ if self.backend == "openslide":
+ return
+ image = {}
+ reader = WSIReaderDeprecated(self.backend)
+ for mode in ["RGB", "RGBA"]:
+ file_path = save_rgba_tiff(
+ img_expected,
+ os.path.join(os.path.dirname(__file__), "testing_data", f"temp_tiff_image_{mode}.tiff"),
+ mode=mode,
+ )
+ with reader.read(file_path) as img_obj:
+ image[mode], _ = reader.get_data(img_obj)
+
+ self.assertIsNone(assert_array_equal(image["RGB"], img_expected))
+ self.assertIsNone(assert_array_equal(image["RGBA"], img_expected))
+
+ @parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D])
+ @skipUnless(has_tiff, "Requires tifffile.")
+ def test_read_malformats(self, img_expected):
+ if self.backend == "cucim" and (len(img_expected.shape) < 3 or img_expected.shape[2] == 1):
+ # Until cuCIM addresses https://github.com/rapidsai/cucim/issues/230
+ return
+ reader = WSIReaderDeprecated(self.backend)
+ file_path = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff")
+ imwrite(file_path, img_expected, shape=img_expected.shape)
+ with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)):
+ with reader.read(file_path) as img_obj:
+ reader.get_data(img_obj)
+
+ @parameterized.expand([TEST_CASE_TRANSFORM_0])
+ def test_with_dataloader(self, file_path, level, expected_spatial_shape, expected_shape):
+ train_transform = Compose(
+ [
+ LoadImaged(keys=["image"], reader=WSIReaderDeprecated, backend=self.backend, level=level),
+ ToTensord(keys=["image"]),
+ ]
+ )
+ dataset = Dataset([{"image": file_path}], transform=train_transform)
+ data_loader = DataLoader(dataset)
+ data: dict = first(data_loader)
+ for s in data[PostFix.meta("image")]["spatial_shape"]:
+ assert_allclose(s, expected_spatial_shape, type_test=False)
+ self.assertTupleEqual(data["image"].shape, expected_shape)
+
+
+class WSIReaderTests:
+ class Tests(unittest.TestCase):
+ backend = None
+
+ @parameterized.expand([TEST_CASE_0])
+ def test_read_whole_image(self, file_path, level, expected_shape):
+ reader = WSIReader(self.backend, level=level)
+ with reader.read(file_path) as img_obj:
+ img, meta = reader.get_data(img_obj)
+ self.assertTupleEqual(img.shape, expected_shape)
+ self.assertEqual(meta["backend"], self.backend)
+ self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())
+ self.assertEqual(meta[WSIPatchKeys.LEVEL], level)
+ assert_array_equal(meta[WSIPatchKeys.SIZE], expected_shape[1:])
+ assert_array_equal(meta[WSIPatchKeys.LOCATION], (0, 0))
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
+ def test_read_region(self, file_path, kwargs, patch_info, expected_img):
+ reader = WSIReader(self.backend, **kwargs)
+ level = patch_info.get("level", kwargs.get("level"))
+ if self.backend == "tifffile" and level < 2:
+ return
+ with reader.read(file_path) as img_obj:
+ # Read twice to check multiple calls
+ img, meta = reader.get_data(img_obj, **patch_info)
+ img2 = reader.get_data(img_obj, **patch_info)[0]
+ self.assertTupleEqual(img.shape, img2.shape)
+ self.assertIsNone(assert_array_equal(img, img2))
+ self.assertTupleEqual(img.shape, expected_img.shape)
+ self.assertIsNone(assert_array_equal(img, expected_img))
+ self.assertEqual(meta["backend"], self.backend)
+ self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())
+ self.assertEqual(meta[WSIPatchKeys.LEVEL], level)
+ assert_array_equal(meta[WSIPatchKeys.SIZE], patch_info["size"])
+ assert_array_equal(meta[WSIPatchKeys.LOCATION], patch_info["location"])
+
+ @parameterized.expand([TEST_CASE_MULTI_WSI])
+ def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img):
+ kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {}
+ reader = WSIReader(self.backend, **kwargs)
+ img_obj_list = reader.read(file_path_list, **kwargs)
+ # Read twice to check multiple calls
+ img, meta = reader.get_data(img_obj_list, **patch_info)
+ img2 = reader.get_data(img_obj_list, **patch_info)[0]
+ for img_obj in img_obj_list:
+ img_obj.close()
+ self.assertTupleEqual(img.shape, img2.shape)
+ self.assertIsNone(assert_array_equal(img, img2))
+ self.assertTupleEqual(img.shape, expected_img.shape)
+ self.assertIsNone(assert_array_equal(img, expected_img))
+ self.assertEqual(meta["backend"], self.backend)
+ self.assertEqual(meta[WSIPatchKeys.PATH][0].lower(), str(os.path.abspath(file_path_list[0])).lower())
+ self.assertEqual(meta[WSIPatchKeys.LEVEL][0], patch_info["level"])
+ assert_array_equal(meta[WSIPatchKeys.SIZE][0], expected_img.shape[1:])
+ assert_array_equal(meta[WSIPatchKeys.LOCATION][0], patch_info["location"])
+
@parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1])
@skipUnless(has_tiff, "Requires tifffile.")
def test_read_rgba(self, img_expected):
@@ -193,7 +361,6 @@ def test_with_dataloader(self, file_path, level, expected_spatial_shape, expecte
train_transform = Compose(
[
LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
- FromMetaTensord(keys=["image"]),
ToTensord(keys=["image"]),
]
)
@@ -201,25 +368,100 @@ def test_with_dataloader(self, file_path, level, expected_spatial_shape, expecte
data_loader = DataLoader(dataset)
data: dict = first(data_loader)
for s in data[PostFix.meta("image")]["spatial_shape"]:
- torch.testing.assert_allclose(s, expected_spatial_shape)
+ assert_allclose(s, expected_spatial_shape, type_test=False)
self.assertTupleEqual(data["image"].shape, expected_shape)
+ @parameterized.expand([TEST_CASE_TRANSFORM_0])
+ def test_with_dataloader_batch(self, file_path, level, expected_spatial_shape, expected_shape):
+ train_transform = Compose(
+ [
+ LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
+ ToTensord(keys=["image"]),
+ ]
+ )
+ dataset = Dataset([{"image": file_path}, {"image": file_path}], transform=train_transform)
+ batch_size = 2
+ data_loader = DataLoader(dataset, batch_size=batch_size)
+ data: dict = first(data_loader)
+ for s in data[PostFix.meta("image")]["spatial_shape"]:
+ assert_allclose(s, expected_spatial_shape, type_test=False)
+ self.assertTupleEqual(data["image"].shape, (batch_size, *expected_shape[1:]))
+
+ @parameterized.expand([TEST_CASE_0])
+ def test_read_whole_image_multi_thread(self, file_path, level, expected_shape):
+ if self.backend == "cucim":
+ reader = WSIReader(self.backend, level=level, num_workers=4)
+ with reader.read(file_path) as img_obj:
+ img, meta = reader.get_data(img_obj)
+ self.assertTupleEqual(img.shape, expected_shape)
+ self.assertEqual(meta["backend"], self.backend)
+ self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())
+ self.assertEqual(meta[WSIPatchKeys.LEVEL], level)
+ assert_array_equal(meta[WSIPatchKeys.SIZE], expected_shape[1:])
+ assert_array_equal(meta[WSIPatchKeys.LOCATION], (0, 0))
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+ def test_read_region_multi_thread(self, file_path, kwargs, patch_info, expected_img):
+ if self.backend == "cucim":
+ reader = WSIReader(backend=self.backend, num_workers=2, **kwargs)
+ with reader.read(file_path) as img_obj:
+ # Read twice to check multiple calls
+ img, meta = reader.get_data(img_obj, **patch_info)
+ img2 = reader.get_data(img_obj, **patch_info)[0]
+ self.assertTupleEqual(img.shape, img2.shape)
+ self.assertIsNone(assert_array_equal(img, img2))
+ self.assertTupleEqual(img.shape, expected_img.shape)
+ self.assertIsNone(assert_array_equal(img, expected_img))
+ self.assertEqual(meta["backend"], self.backend)
+ self.assertEqual(meta[WSIPatchKeys.PATH].lower(), str(os.path.abspath(file_path)).lower())
+ self.assertEqual(meta[WSIPatchKeys.LEVEL], patch_info["level"])
+ assert_array_equal(meta[WSIPatchKeys.SIZE], patch_info["size"])
+ assert_array_equal(meta[WSIPatchKeys.LOCATION], patch_info["location"])
+
+ @parameterized.expand([TEST_CASE_MPP_0])
+ def test_resolution_mpp(self, file_path, level, expected_mpp):
+ reader = WSIReader(self.backend, level=level)
+ with reader.read(file_path) as img_obj:
+ mpp = reader.get_mpp(img_obj, level)
+ self.assertTupleEqual(mpp, expected_mpp)
+
@skipUnless(has_cucim, "Requires cucim")
-class TestCuCIM(WSIReaderTests.Tests):
+class TestCuCIMDeprecated(WSIReaderDeprecatedTests.Tests):
@classmethod
def setUpClass(cls):
cls.backend = "cucim"
@skipUnless(has_osl, "Requires OpenSlide")
-class TestOpenSlide(WSIReaderTests.Tests):
+class TestOpenSlideDeprecated(WSIReaderDeprecatedTests.Tests):
@classmethod
def setUpClass(cls):
cls.backend = "openslide"
@skipUnless(has_tiff, "Requires TiffFile")
+class TestTiffFileDeprecated(WSIReaderDeprecatedTests.Tests):
+ @classmethod
+ def setUpClass(cls):
+ cls.backend = "tifffile"
+
+
+@skipUnless(has_cucim, "Requires cucim")
+class TestCuCIM(WSIReaderTests.Tests):
+ @classmethod
+ def setUpClass(cls):
+ cls.backend = "cucim"
+
+
+@skipUnless(has_osl, "Requires openslide")
+class TestOpenSlide(WSIReaderTests.Tests):
+ @classmethod
+ def setUpClass(cls):
+ cls.backend = "openslide"
+
+
+@skipUnless(has_tiff, "Requires tifffile")
class TestTiffFile(WSIReaderTests.Tests):
@classmethod
def setUpClass(cls):
diff --git a/tests/test_wsireader_new.py b/tests/test_wsireader_new.py
deleted file mode 100644
index 0d5e5892e64..00000000000
--- a/tests/test_wsireader_new.py
+++ /dev/null
@@ -1,277 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import unittest
-from unittest import skipUnless
-
-import numpy as np
-from numpy.testing import assert_array_equal
-from parameterized import parameterized
-
-from monai.data import DataLoader, Dataset
-from monai.data.wsi_reader import WSIReader
-from monai.transforms import Compose, FromMetaTensord, LoadImaged, ToTensord
-from monai.utils import first, optional_import
-from monai.utils.enums import PostFix
-from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config
-
-cucim, has_cucim = optional_import("cucim")
-has_cucim = has_cucim and hasattr(cucim, "CuImage")
-openslide, has_osl = optional_import("openslide")
-imwrite, has_tiff = optional_import("tifffile", name="imwrite")
-_, has_codec = optional_import("imagecodecs")
-has_tiff = has_tiff and has_codec
-
-FILE_KEY = "wsi_img"
-FILE_URL = testing_data_config("images", FILE_KEY, "url")
-base_name, extension = os.path.basename(f"{FILE_URL}"), ".tiff"
-FILE_PATH = os.path.join(os.path.dirname(__file__), "testing_data", "temp_" + base_name + extension)
-
-HEIGHT = 32914
-WIDTH = 46000
-
-TEST_CASE_0 = [FILE_PATH, 2, (3, HEIGHT // 4, WIDTH // 4)]
-
-TEST_CASE_TRANSFORM_0 = [FILE_PATH, 4, (HEIGHT // 16, WIDTH // 16), (1, 3, HEIGHT // 16, WIDTH // 16)]
-
-TEST_CASE_1 = [
- FILE_PATH,
- {},
- {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
- np.array([[[246], [246]], [[246], [246]], [[246], [246]]]),
-]
-
-TEST_CASE_2 = [
- FILE_PATH,
- {},
- {"location": (0, 0), "size": (2, 1), "level": 2},
- np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
-]
-
-TEST_CASE_3 = [
- FILE_PATH,
- {"channel_dim": -1},
- {"location": (HEIGHT // 2, WIDTH // 2), "size": (2, 1), "level": 0},
- np.moveaxis(np.array([[[246], [246]], [[246], [246]], [[246], [246]]]), 0, -1),
-]
-
-TEST_CASE_4 = [
- FILE_PATH,
- {"channel_dim": 2},
- {"location": (0, 0), "size": (2, 1), "level": 2},
- np.moveaxis(np.array([[[239], [239]], [[239], [239]], [[239], [239]]]), 0, -1),
-]
-
-TEST_CASE_MULTI_WSI = [
- [FILE_PATH, FILE_PATH],
- {"location": (0, 0), "size": (2, 1), "level": 2},
- np.concatenate(
- [
- np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
- np.array([[[239], [239]], [[239], [239]], [[239], [239]]]),
- ],
- axis=0,
- ),
-]
-
-
-TEST_CASE_RGB_0 = [np.ones((3, 2, 2), dtype=np.uint8)] # CHW
-
-TEST_CASE_RGB_1 = [np.ones((3, 100, 100), dtype=np.uint8)] # CHW
-
-TEST_CASE_ERROR_0C = [np.ones((16, 16), dtype=np.uint8)] # no color channel
-TEST_CASE_ERROR_1C = [np.ones((16, 16, 1), dtype=np.uint8)] # one color channel
-TEST_CASE_ERROR_2C = [np.ones((16, 16, 2), dtype=np.uint8)] # two color channels
-TEST_CASE_ERROR_3D = [np.ones((16, 16, 16, 3), dtype=np.uint8)] # 3D + color
-
-
-def save_rgba_tiff(array: np.ndarray, filename: str, mode: str):
- """
- Save numpy array into a TIFF RGB/RGBA file
-
- Args:
- array: numpy ndarray with the shape of CxHxW and C==3 representing a RGB image
- filename: the filename to be used for the tiff file. '_RGB.tiff' or '_RGBA.tiff' will be appended to this filename.
- mode: RGB or RGBA
- """
- if mode == "RGBA":
- array = np.concatenate([array, 255 * np.ones_like(array[0])[np.newaxis]]).astype(np.uint8)
-
- img_rgb = array.transpose(1, 2, 0)
- imwrite(filename, img_rgb, shape=img_rgb.shape, tile=(16, 16))
-
- return filename
-
-
-def save_gray_tiff(array: np.ndarray, filename: str):
- """
- Save numpy array into a TIFF file
-
- Args:
- array: numpy ndarray with any shape
- filename: the filename to be used for the tiff file.
- """
- img_gray = array
- imwrite(filename, img_gray, shape=img_gray.shape)
-
- return filename
-
-
-@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!")
-def setUpModule():
- hash_type = testing_data_config("images", FILE_KEY, "hash_type")
- hash_val = testing_data_config("images", FILE_KEY, "hash_val")
- download_url_or_skip_test(FILE_URL, FILE_PATH, hash_type=hash_type, hash_val=hash_val)
-
-
-class WSIReaderTests:
- class Tests(unittest.TestCase):
- backend = None
-
- @parameterized.expand([TEST_CASE_0])
- def test_read_whole_image(self, file_path, level, expected_shape):
- reader = WSIReader(self.backend, level=level)
- with reader.read(file_path) as img_obj:
- img, meta = reader.get_data(img_obj)
- self.assertTupleEqual(img.shape, expected_shape)
- self.assertEqual(meta["backend"], self.backend)
- self.assertEqual(meta["path"], str(os.path.abspath(file_path)))
- self.assertEqual(meta["patch_level"], level)
- assert_array_equal(meta["patch_size"], expected_shape[1:])
- assert_array_equal(meta["patch_location"], (0, 0))
-
- @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
- def test_read_region(self, file_path, kwargs, patch_info, expected_img):
- reader = WSIReader(self.backend, **kwargs)
- with reader.read(file_path) as img_obj:
- if self.backend == "tifffile":
- with self.assertRaises(ValueError):
- reader.get_data(img_obj, **patch_info)[0]
- else:
- # Read twice to check multiple calls
- img, meta = reader.get_data(img_obj, **patch_info)
- img2 = reader.get_data(img_obj, **patch_info)[0]
- self.assertTupleEqual(img.shape, img2.shape)
- self.assertIsNone(assert_array_equal(img, img2))
- self.assertTupleEqual(img.shape, expected_img.shape)
- self.assertIsNone(assert_array_equal(img, expected_img))
- self.assertEqual(meta["backend"], self.backend)
- self.assertEqual(meta["path"], str(os.path.abspath(file_path)))
- self.assertEqual(meta["patch_level"], patch_info["level"])
- assert_array_equal(meta["patch_size"], patch_info["size"])
- assert_array_equal(meta["patch_location"], patch_info["location"])
-
- @parameterized.expand([TEST_CASE_MULTI_WSI])
- def test_read_region_multi_wsi(self, file_path_list, patch_info, expected_img):
- kwargs = {"name": None, "offset": None} if self.backend == "tifffile" else {}
- reader = WSIReader(self.backend, **kwargs)
- img_obj_list = reader.read(file_path_list, **kwargs)
- if self.backend == "tifffile":
- with self.assertRaises(ValueError):
- reader.get_data(img_obj_list, **patch_info)[0]
- else:
- # Read twice to check multiple calls
- img, meta = reader.get_data(img_obj_list, **patch_info)
- img2 = reader.get_data(img_obj_list, **patch_info)[0]
- self.assertTupleEqual(img.shape, img2.shape)
- self.assertIsNone(assert_array_equal(img, img2))
- self.assertTupleEqual(img.shape, expected_img.shape)
- self.assertIsNone(assert_array_equal(img, expected_img))
- self.assertEqual(meta["backend"], self.backend)
- self.assertEqual(meta["path"][0], str(os.path.abspath(file_path_list[0])))
- self.assertEqual(meta["patch_level"][0], patch_info["level"])
- assert_array_equal(meta["patch_size"][0], expected_img.shape[1:])
- assert_array_equal(meta["patch_location"][0], patch_info["location"])
-
- @parameterized.expand([TEST_CASE_RGB_0, TEST_CASE_RGB_1])
- @skipUnless(has_tiff, "Requires tifffile.")
- def test_read_rgba(self, img_expected):
- # skip for OpenSlide since not working with images without tiles
- if self.backend == "openslide":
- return
- image = {}
- reader = WSIReader(self.backend)
- for mode in ["RGB", "RGBA"]:
- file_path = save_rgba_tiff(
- img_expected,
- os.path.join(os.path.dirname(__file__), "testing_data", f"temp_tiff_image_{mode}.tiff"),
- mode=mode,
- )
- with reader.read(file_path) as img_obj:
- image[mode], _ = reader.get_data(img_obj)
-
- self.assertIsNone(assert_array_equal(image["RGB"], img_expected))
- self.assertIsNone(assert_array_equal(image["RGBA"], img_expected))
-
- @parameterized.expand([TEST_CASE_ERROR_0C, TEST_CASE_ERROR_1C, TEST_CASE_ERROR_2C, TEST_CASE_ERROR_3D])
- @skipUnless(has_tiff, "Requires tifffile.")
- def test_read_malformats(self, img_expected):
- if self.backend == "cucim" and (len(img_expected.shape) < 3 or img_expected.shape[2] == 1):
- # Until cuCIM addresses https://github.com/rapidsai/cucim/issues/230
- return
- reader = WSIReader(self.backend)
- file_path = os.path.join(os.path.dirname(__file__), "testing_data", "temp_tiff_image_gray.tiff")
- imwrite(file_path, img_expected, shape=img_expected.shape)
- with self.assertRaises((RuntimeError, ValueError, openslide.OpenSlideError if has_osl else ValueError)):
- with reader.read(file_path) as img_obj:
- reader.get_data(img_obj)
-
- @parameterized.expand([TEST_CASE_TRANSFORM_0])
- def test_with_dataloader(self, file_path, level, expected_spatial_shape, expected_shape):
- train_transform = Compose(
- [
- LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
- FromMetaTensord(keys=["image"]),
- ToTensord(keys=["image"]),
- ]
- )
- dataset = Dataset([{"image": file_path}], transform=train_transform)
- data_loader = DataLoader(dataset)
- data: dict = first(data_loader)
- for s in data[PostFix.meta("image")]["spatial_shape"]:
- assert_allclose(s, expected_spatial_shape, type_test=False)
- self.assertTupleEqual(data["image"].shape, expected_shape)
-
- @parameterized.expand([TEST_CASE_TRANSFORM_0])
- def test_with_dataloader_batch(self, file_path, level, expected_spatial_shape, expected_shape):
- train_transform = Compose(
- [
- LoadImaged(keys=["image"], reader=WSIReader, backend=self.backend, level=level),
- FromMetaTensord(keys=["image"]),
- ToTensord(keys=["image"]),
- ]
- )
- dataset = Dataset([{"image": file_path}, {"image": file_path}], transform=train_transform)
- batch_size = 2
- data_loader = DataLoader(dataset, batch_size=batch_size)
- data: dict = first(data_loader)
- for s in data[PostFix.meta("image")]["spatial_shape"]:
- assert_allclose(s, expected_spatial_shape, type_test=False)
- self.assertTupleEqual(data["image"].shape, (batch_size, *expected_shape[1:]))
-
-
-@skipUnless(has_cucim, "Requires cucim")
-class TestCuCIM(WSIReaderTests.Tests):
- @classmethod
- def setUpClass(cls):
- cls.backend = "cucim"
-
-
-@skipUnless(has_osl, "Requires openslide")
-class TestOpenSlide(WSIReaderTests.Tests):
- @classmethod
- def setUpClass(cls):
- cls.backend = "openslide"
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/test_zoom.py b/tests/test_zoom.py
index 78beec69a1a..1d0447e319b 100644
--- a/tests/test_zoom.py
+++ b/tests/test_zoom.py
@@ -12,7 +12,6 @@
import unittest
import numpy as np
-import torch
from parameterized import parameterized
from scipy.ndimage import zoom as zoom_scipy
@@ -75,7 +74,7 @@ def test_padding_mode(self):
test_data = p([[[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]])
zoomed = zoom_fn(test_data)
expected = p([[[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])
- torch.testing.assert_allclose(zoomed, expected)
+ assert_allclose(zoomed, expected, type_test=False)
if __name__ == "__main__":
diff --git a/tests/testing_data/anatomical_label.nii.gz b/tests/testing_data/anatomical_label.nii.gz
new file mode 100644
index 00000000000..a31ef9f7a42
Binary files /dev/null and b/tests/testing_data/anatomical_label.nii.gz differ
diff --git a/tests/testing_data/config_fl_evaluate.json b/tests/testing_data/config_fl_evaluate.json
new file mode 100644
index 00000000000..113596070aa
--- /dev/null
+++ b/tests/testing_data/config_fl_evaluate.json
@@ -0,0 +1,87 @@
+{
+ "bundle_root": "tests/testing_data",
+ "dataset_dir": "@bundle_root",
+ "imports": [
+ "$import os"
+ ],
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
+ "network_def": {
+ "_target_": "DenseNet121",
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 6
+ },
+ "network": "$@network_def.to(@device)",
+ "validate": {
+ "val_transforms": [
+ {
+ "_target_": "LoadImaged",
+ "keys": [
+ "image"
+ ],
+ "image_only": true
+ },
+ {
+ "_target_": "EnsureChannelFirstD",
+ "keys": [
+ "image"
+ ]
+ },
+ {
+ "_target_": "ScaleIntensityd",
+ "keys": [
+ "image"
+ ]
+ },
+ {
+ "_target_": "ToTensord",
+ "keys": [
+ "image",
+ "label"
+ ]
+ }
+ ],
+ "preprocessing": {
+ "_target_": "Compose",
+ "transforms": "$@validate#val_transforms"
+ },
+ "dataset": {
+ "_target_": "Dataset",
+ "data": [
+ {
+ "image": "$os.path.join(@dataset_dir, 'image0.jpeg')",
+ "label": 0
+ },
+ {
+ "image": "$os.path.join(@dataset_dir, 'image1.jpeg')",
+ "label": 1
+ }
+ ],
+ "transform": "@validate#preprocessing"
+ },
+ "dataloader": {
+ "_target_": "DataLoader",
+ "dataset": "@validate#dataset",
+ "batch_size": 3,
+ "shuffle": false,
+ "num_workers": 4
+ },
+ "inferer": {
+ "_target_": "SimpleInferer"
+ },
+ "key_metric": {
+ "accuracy": {
+ "_target_": "ignite.metrics.Accuracy",
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
+ }
+ },
+ "evaluator": {
+ "_target_": "SupervisedEvaluator",
+ "device": "@device",
+ "val_data_loader": "@validate#dataloader",
+ "network": "@network",
+ "inferer": "@validate#inferer",
+ "key_val_metric": "@validate#key_metric"
+ }
+ }
+}
diff --git a/tests/testing_data/config_fl_filters.json b/tests/testing_data/config_fl_filters.json
new file mode 100644
index 00000000000..5ccafa334c2
--- /dev/null
+++ b/tests/testing_data/config_fl_filters.json
@@ -0,0 +1,13 @@
+{
+ "pre_filters": [
+ {
+ "_target_": "monai.fl.utils.filters.SummaryFilter"
+ }
+ ],
+ "post_weight_filters": [
+ {
+ "_target_": "monai.fl.utils.filters.SummaryFilter"
+ }
+ ],
+ "post_evaluate_filters": []
+}
diff --git a/tests/testing_data/config_fl_stats_1.json b/tests/testing_data/config_fl_stats_1.json
new file mode 100644
index 00000000000..41b42eb3bb2
--- /dev/null
+++ b/tests/testing_data/config_fl_stats_1.json
@@ -0,0 +1,23 @@
+{
+ "imports": [
+ "$import os"
+ ],
+ "bundle_root": "tests/testing_data",
+ "dataset_dir": "@bundle_root",
+ "train": {
+ "dataset": {
+ "_target_": "Dataset",
+ "data": [
+ {
+ "image": "$os.path.join(@dataset_dir, 'anatomical.nii')",
+ "label": "$os.path.join(@dataset_dir, 'anatomical_label.nii.gz')"
+ },
+ {
+ "image": "$os.path.join(@dataset_dir, 'reoriented_anat_moved.nii')",
+ "label": "$os.path.join(@dataset_dir, 'reoriented_anat_moved_label.nii.gz')"
+ }
+ ],
+ "transform": "@train#preprocessing"
+ }
+ }
+}
diff --git a/tests/testing_data/config_fl_stats_2.json b/tests/testing_data/config_fl_stats_2.json
new file mode 100644
index 00000000000..bf55673f67a
--- /dev/null
+++ b/tests/testing_data/config_fl_stats_2.json
@@ -0,0 +1,39 @@
+{
+ "imports": [
+ "$import os"
+ ],
+ "bundle_root": "tests/testing_data",
+ "dataset_dir": "@bundle_root",
+ "train": {
+ "dataset": {
+ "_target_": "Dataset",
+ "data": [
+ {
+ "image": "$os.path.join(@dataset_dir, 'anatomical.nii')",
+ "label": "$os.path.join(@dataset_dir, 'anatomical_label.nii.gz')"
+ },
+ {
+ "image": "$os.path.join(@dataset_dir, 'reoriented_anat_moved.nii')",
+ "label": "$os.path.join(@dataset_dir, 'reoriented_anat_moved_label.nii.gz')"
+ }
+ ],
+ "transform": "@train#preprocessing"
+ }
+ },
+ "validate": {
+ "dataset": {
+ "_target_": "Dataset",
+ "data": [
+ {
+ "image": "$os.path.join(@dataset_dir, 'anatomical.nii')",
+ "label": "$os.path.join(@dataset_dir, 'anatomical_label.nii.gz')"
+ },
+ {
+ "image": "$os.path.join(@dataset_dir, 'reoriented_anat_moved.nii')",
+ "label": "$os.path.join(@dataset_dir, 'reoriented_anat_moved_label.nii.gz')"
+ }
+ ],
+ "transform": "@train#preprocessing"
+ }
+ }
+}
diff --git a/tests/testing_data/config_fl_train.json b/tests/testing_data/config_fl_train.json
new file mode 100644
index 00000000000..f53a95bc02d
--- /dev/null
+++ b/tests/testing_data/config_fl_train.json
@@ -0,0 +1,125 @@
+{
+ "bundle_root": "tests/testing_data",
+ "dataset_dir": "@bundle_root",
+ "imports": [
+ "$import os"
+ ],
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
+ "network_def": {
+ "_target_": "DenseNet121",
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 6
+ },
+ "network": "$@network_def.to(@device)",
+ "loss": {
+ "_target_": "torch.nn.CrossEntropyLoss"
+ },
+ "optimizer": {
+ "_target_": "torch.optim.Adam",
+ "params": "$@network.parameters()",
+ "lr": 0.0001
+ },
+ "train": {
+ "training_transforms": [
+ {
+ "_target_": "LoadImaged",
+ "keys": [
+ "image"
+ ],
+ "image_only": true
+ },
+ {
+ "_target_": "EnsureChannelFirstD",
+ "keys": [
+ "image"
+ ]
+ },
+ {
+ "_target_": "ScaleIntensityd",
+ "keys": [
+ "image"
+ ]
+ },
+ {
+ "_target_": "RandRotated",
+ "keys": [
+ "image"
+ ],
+ "range_x": 15,
+ "prob": 0.5,
+ "keep_size": true
+ },
+ {
+ "_target_": "RandFlipd",
+ "keys": [
+ "image"
+ ],
+ "spatial_axis": 0,
+ "prob": 0.5
+ },
+ {
+ "_target_": "RandZoomd",
+ "keys": [
+ "image"
+ ],
+ "min_zoom": 0.9,
+ "max_zoom": 1.1,
+ "prob": 0.5
+ },
+ {
+ "_target_": "ToTensord",
+ "keys": [
+ "image",
+ "label"
+ ]
+ }
+ ],
+ "preprocessing": {
+ "_target_": "Compose",
+ "transforms": "$@train#training_transforms"
+ },
+ "dataset": {
+ "_target_": "Dataset",
+ "data": [
+ {
+ "image": "$os.path.join(@dataset_dir, 'image0.jpeg')",
+ "label": 0
+ },
+ {
+ "image": "$os.path.join(@dataset_dir, 'image1.jpeg')",
+ "label": 1
+ }
+ ],
+ "transform": "@train#preprocessing"
+ },
+ "dataloader": {
+ "_target_": "DataLoader",
+ "dataset": "@train#dataset",
+ "batch_size": 3,
+ "shuffle": true,
+ "num_workers": 2
+ },
+ "inferer": {
+ "_target_": "SimpleInferer"
+ },
+ "handlers": [
+ {
+ "_target_": "StatsHandler",
+ "tag_name": "train_loss",
+ "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
+ }
+ ],
+ "trainer": {
+ "_target_": "SupervisedTrainer",
+ "max_epochs": 2,
+ "device": "@device",
+ "train_data_loader": "@train#dataloader",
+ "network": "@network",
+ "loss_function": "@loss",
+ "optimizer": "@optimizer",
+ "inferer": "@train#inferer",
+ "train_handlers": "@train#handlers"
+ }
+ }
+}
diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json
index 254314d1b87..788d6644397 100644
--- a/tests/testing_data/data_config.json
+++ b/tests/testing_data/data_config.json
@@ -26,12 +26,12 @@
"hash_val": "a14231f539c0f365a5f83f2a046969a9b9870e56ffd126fd8e7242364d25938a"
},
"0000_t2_tse_tra_4": {
- "url": "https://github.com/rcuocolo/PROSTATEx_masks/raw/master/Files/lesions/Images/T2/ProstateX-0000_t2_tse_tra_4.nii.gz",
+ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ProstateX-0000_t2_tse_tra_4.nii.gz",
"hash_type": "md5",
"hash_val": "adb3f1c4db66a6481c3e4a2a3033c7d5"
},
"0000_ep2d_diff_tra_7": {
- "url": "https://github.com/rcuocolo/PROSTATEx_masks/raw/master/Files/lesions/Images/ADC/ProstateX-0000_ep2d_diff_tra_7.nii.gz",
+ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ProstateX-0000_ep2d_diff_tra_7.nii.gz",
"hash_type": "md5",
"hash_val": "f12a11ad0ebb0b1876e9e010564745d2"
},
@@ -56,6 +56,18 @@
"hash_val": "eb4f1e596ca85aadaefc359d409fb9a3e27d733e6def04b996953b7c54bc26d4"
}
},
+ "videos": {
+ "endovis": {
+ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/d1_im.mp4",
+ "hash_type": "md5",
+ "hash_val": "9b103c07326439b0ea376018d7189384"
+ },
+ "ultrasound": {
+ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/example_data_Ultrasound_Q000_04_tu_segmented_ultrasound_256.avi",
+ "hash_type": "md5",
+ "hash_val": "f0755960cc4a08a958561cda9a79a157"
+ }
+ },
"models": {
"senet154-c7b49a05": {
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/senet154-c7b49a05.pth",
diff --git a/tests/testing_data/image0.jpeg b/tests/testing_data/image0.jpeg
new file mode 100644
index 00000000000..5025e5f9928
Binary files /dev/null and b/tests/testing_data/image0.jpeg differ
diff --git a/tests/testing_data/image1.jpeg b/tests/testing_data/image1.jpeg
new file mode 100644
index 00000000000..462db4bbc26
Binary files /dev/null and b/tests/testing_data/image1.jpeg differ
diff --git a/tests/testing_data/integration_answers.py b/tests/testing_data/integration_answers.py
index 99765a2b338..1907e3c4a54 100644
--- a/tests/testing_data/integration_answers.py
+++ b/tests/testing_data/integration_answers.py
@@ -433,6 +433,64 @@
"infer_metric": 0.9326590299606323,
},
},
+ { # test answers for PyTorch 1.13
+ "integration_workflows": {
+ "output_sums_2": [
+ 0.14264830205979873,
+ 0.15264129328718357,
+ 0.1519652511118344,
+ 0.14003114557361543,
+ 0.18870416611118465,
+ 0.1699260498246968,
+ 0.14727475398203582,
+ 0.16870874483246967,
+ 0.15757932277023196,
+ 0.1797779694564011,
+ 0.16310501082450635,
+ 0.16850569170136015,
+ 0.14472958359864832,
+ 0.11402527744419455,
+ 0.16217657428257873,
+ 0.20135486560244975,
+ 0.17627557567092866,
+ 0.09802074024435596,
+ 0.19418729084978026,
+ 0.20339278025379662,
+ 0.1966174446916041,
+ 0.20872528599049203,
+ 0.16246183433492764,
+ 0.1323750751202327,
+ 0.14830347036335728,
+ 0.14300732028781024,
+ 0.23163101813922762,
+ 0.1612925258625139,
+ 0.1489573676973957,
+ 0.10299491921717041,
+ 0.11921404797064328,
+ 0.1300212751422368,
+ 0.11437829790254125,
+ 0.1524755276727056,
+ 0.16350584736767904,
+ 0.19424317961257148,
+ 0.2229762916892286,
+ 0.18121074825540173,
+ 0.19064286213535897,
+ 0.0747544243069024,
+ ]
+ },
+ "integration_segmentation_3d": { # for the mixed readers
+ "losses": [
+ 0.5451162219047546,
+ 0.4709601759910583,
+ 0.45201429128646853,
+ 0.4443251401185989,
+ 0.4341257899999619,
+ 0.4350819975137711,
+ ],
+ "best_metric": 0.9316844940185547,
+ "infer_metric": 0.9316383600234985,
+ },
+ },
{ # test answers for PyTorch 21.10
"integration_classification_2d": {
"losses": [0.7806222991199251, 0.16259610306495315, 0.07529311385124353, 0.04640352608529246],
diff --git a/tests/testing_data/multi_gpu_evaluate.json b/tests/testing_data/multi_gpu_evaluate.json
new file mode 100644
index 00000000000..7af24a6b2e1
--- /dev/null
+++ b/tests/testing_data/multi_gpu_evaluate.json
@@ -0,0 +1,27 @@
+{
+ "device": "$torch.device(f'cuda:{dist.get_rank()}')",
+ "network": {
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
+ "module": "$@network_def.to(@device)",
+ "device_ids": [
+ "@device"
+ ]
+ },
+ "validate#sampler": {
+ "_target_": "DistributedSampler",
+ "dataset": "@validate#dataset",
+ "even_divisible": false,
+ "shuffle": false
+ },
+ "validate#dataloader#sampler": "@validate#sampler",
+ "evaluating": [
+ "$import torch.distributed as dist",
+ "$dist.init_process_group(backend='nccl')",
+ "$torch.cuda.set_device(@device)",
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
+ "$import logging",
+ "$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
+ "$@validate#evaluator.run()",
+ "$dist.destroy_process_group()"
+ ]
+}
diff --git a/tests/testing_data/multi_gpu_train.json b/tests/testing_data/multi_gpu_train.json
new file mode 100644
index 00000000000..41fd7698db5
--- /dev/null
+++ b/tests/testing_data/multi_gpu_train.json
@@ -0,0 +1,30 @@
+{
+ "device": "$torch.device(f'cuda:{dist.get_rank()}')",
+ "network": {
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
+ "module": "$@network_def.to(@device)",
+ "device_ids": [
+ "@device"
+ ]
+ },
+ "train#sampler": {
+ "_target_": "DistributedSampler",
+ "dataset": "@train#dataset",
+ "even_divisible": true,
+ "shuffle": true
+ },
+ "train#dataloader#sampler": "@train#sampler",
+ "train#dataloader#shuffle": false,
+ "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
+ "training": [
+ "$import torch.distributed as dist",
+ "$dist.init_process_group(backend='nccl')",
+ "$torch.cuda.set_device(@device)",
+ "$monai.utils.set_determinism(seed=123)",
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
+ "$import logging",
+ "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
+ "$@train#trainer.run()",
+ "$dist.destroy_process_group()"
+ ]
+}
diff --git a/tests/testing_data/reoriented_anat_moved_label.nii.gz b/tests/testing_data/reoriented_anat_moved_label.nii.gz
new file mode 100644
index 00000000000..2d148d19995
Binary files /dev/null and b/tests/testing_data/reoriented_anat_moved_label.nii.gz differ
diff --git a/tests/testing_data/signal.npy b/tests/testing_data/signal.npy
new file mode 100755
index 00000000000..7803f0d371d
Binary files /dev/null and b/tests/testing_data/signal.npy differ
diff --git a/tests/testing_data/transform_metatensor_cases.yaml b/tests/testing_data/transform_metatensor_cases.yaml
index 75275b9df51..a639a133091 100644
--- a/tests/testing_data/transform_metatensor_cases.yaml
+++ b/tests/testing_data/transform_metatensor_cases.yaml
@@ -180,7 +180,7 @@ TEST_CASE_3:
TEST_CASE_1_answer:
load_shape: [1, 1, 33, 45, 54]
- affine: "$np.array([[-2, 0, 0, 34], [0, 2, 0, -64], [0, 0, 2, -54], [0, 0, 0, 1]], dtype=np.float64)"
+ affine: "$np.array([[-2, 0, 0, 30], [0, 2, 0, -62], [0, 0, 2, -48], [0, 0, 0, 1]], dtype=np.float64)"
inv_affine: "@init_affine"
inv_shape: "@init_shape"
@@ -192,6 +192,6 @@ TEST_CASE_2_answer:
TEST_CASE_3_answer:
load_shape: [1, 1, 72, 57, 82]
- affine: "$np.array([[-1.343816, -0.682904, -0.234832, 76.01494], [0.309004, 0.653211, -1.734872, 24.511358], [-0.104049, 1.617199, 0.584171, -56.521294], [0, 0, 0, 1]], dtype=np.float64)"
+ affine: "$np.array([[1.300558, -0.700765, -0.511861, -3.739605], [0.479723, -1.171149, 1.193079, -50.087933], [0.395736, 1.183532, 0.984201, -80.496605], [0, 0, 0, 1]], dtype=np.float64)"
inv_affine: "@init_affine"
inv_shape: "@init_shape"
diff --git a/tests/utils.py b/tests/utils.py
index c101196830a..afe08e0bfae 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -18,6 +18,7 @@
import os
import queue
import ssl
+import subprocess
import sys
import tempfile
import time
@@ -46,6 +47,7 @@
from monai.utils.type_conversion import convert_data_type
nib, _ = optional_import("nibabel")
+http_error, has_requests = optional_import("requests", name="HTTPError")
quick_test_var = "QUICKTEST"
_tf32_enabled = None
@@ -122,7 +124,7 @@ def assert_allclose(
def skip_if_downloading_fails():
try:
yield
- except (ContentTooShortError, HTTPError, ConnectionError) as e:
+ except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_requests else () as e:
raise unittest.SkipTest(f"error while downloading: {e}") from e
except ssl.SSLError as ssl_e:
if "decryption failed" in str(ssl_e):
@@ -136,6 +138,8 @@ def skip_if_downloading_fails():
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
if "md5 check" in str(rt_e):
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
+ if "limit" in str(rt_e): # HTTP Error 503: Egress is over the account limit
+ raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e
raise rt_e
@@ -403,6 +407,7 @@ def __init__(
timeout: Timeout for operations executed against the process group.
init_method: URL specifying how to initialize the process group.
Default is "env://" or "file:///d:/a_temp" (windows) if unspecified.
+ If ``"no_init"``, the `dist.init_process_group` must be called within the code to be tested.
backend: The backend to use. Depending on build-time configurations,
valid values include ``mpi``, ``gloo``, and ``nccl``.
daemon: the process’s daemon flag.
@@ -450,13 +455,14 @@ def run_process(self, func, local_rank, args, kwargs, results):
if torch.cuda.is_available():
torch.cuda.set_device(int(local_rank)) # using device ids from CUDA_VISIBILE_DEVICES
- dist.init_process_group(
- backend=self.backend,
- init_method=self.init_method,
- timeout=self.timeout,
- world_size=int(os.environ["WORLD_SIZE"]),
- rank=int(os.environ["RANK"]),
- )
+ if self.init_method != "no_init":
+ dist.init_process_group(
+ backend=self.backend,
+ init_method=self.init_method,
+ timeout=self.timeout,
+ world_size=int(os.environ["WORLD_SIZE"]),
+ rank=int(os.environ["RANK"]),
+ )
func(*args, **kwargs)
# the primary node lives longer to
# avoid _store_based_barrier, RuntimeError: Broken pipe
@@ -692,16 +698,21 @@ def test_script_save(net, *inputs, device=None, rtol=1e-4, atol=0.0):
"""
# TODO: would be nice to use GPU if available, but it currently causes CI failures.
device = "cpu"
- with tempfile.TemporaryDirectory() as tempdir:
- convert_to_torchscript(
- model=net,
- filename_or_obj=os.path.join(tempdir, "model.ts"),
- verify=True,
- inputs=inputs,
- device=device,
- rtol=rtol,
- atol=atol,
- )
+ try:
+ with tempfile.TemporaryDirectory() as tempdir:
+ convert_to_torchscript(
+ model=net,
+ filename_or_obj=os.path.join(tempdir, "model.ts"),
+ verify=True,
+ inputs=inputs,
+ device=device,
+ rtol=rtol,
+ atol=atol,
+ )
+ except (RuntimeError, AttributeError):
+ if sys.version_info.major == 3 and sys.version_info.minor == 11:
+ warnings.warn("skipping py 3.11")
+ return
def download_url_or_skip_test(*args, **kwargs):
@@ -743,6 +754,18 @@ def test_local_inversion(invertible_xform, to_invert, im, dict_key=None):
assert_allclose(im_inv.affine, im_ref.affine, atol=1e-3, rtol=1e-3)
+def command_line_tests(cmd, copy_env=True):
+ test_env = os.environ.copy() if copy_env else os.environ
+ print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES"))
+ try:
+ normal_out = subprocess.run(cmd, env=test_env, check=True, capture_output=True)
+ print(repr(normal_out).replace("\\n", "\n").replace("\\t", "\t"))
+ except subprocess.CalledProcessError as e:
+ output = repr(e.stdout).replace("\\n", "\n").replace("\\t", "\t")
+ errors = repr(e.stderr).replace("\\n", "\n").replace("\\t", "\t")
+ raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e
+
+
TEST_TORCH_TENSORS: Tuple = (torch.as_tensor,)
if torch.cuda.is_available():
gpu_tensor: Callable = partial(torch.as_tensor, device="cuda")
@@ -758,11 +781,10 @@ def test_local_inversion(invertible_xform, to_invert, im, dict_key=None):
# alias for branch tests
TEST_NDARRAYS_ALL = TEST_NDARRAYS
-
TEST_DEVICES = [[torch.device("cpu")]]
if torch.cuda.is_available():
TEST_DEVICES.append([torch.device("cuda")])
-
if __name__ == "__main__":
- print(query_memory())
+ print("\n", query_memory(), sep="\n") # print to stdout
+ sys.exit(0)
diff --git a/versioneer.py b/versioneer.py
index 9112ac66a5e..a06587fc3fc 100644
--- a/versioneer.py
+++ b/versioneer.py
@@ -1,4 +1,4 @@
-# Version: 0.19
+# Version: 0.23
"""The Versioneer - like a rocketeer, but for versions.
@@ -8,12 +8,12 @@
* like a rocketeer, but for versions!
* https://github.com/python-versioneer/python-versioneer
* Brian Warner
-* License: Public Domain
-* Compatible with: Python 3.6, 3.7, 3.8, 3.9 and pypy3
+* License: Public Domain (CC0-1.0)
+* Compatible with: Python 3.7, 3.8, 3.9, 3.10 and pypy3
* [![Latest Version][pypi-image]][pypi-url]
* [![Build Status][travis-image]][travis-url]
-This is a tool for managing a recorded version number in distutils-based
+This is a tool for managing a recorded version number in distutils/setuptools-based
python projects. The goal is to remove the tedious and error-prone "update
the embedded version string" step from your release process. Making a new
release should be as easy as recording a new tag in your version-control
@@ -255,6 +255,8 @@
dependency
* [minver](https://github.com/jbweston/miniver) - a lightweight reimplementation of
versioneer
+* [versioningit](https://github.com/jwodder/versioningit) - a PEP 518-based setuptools
+ plugin
## License
@@ -271,6 +273,11 @@
[travis-url]: https://travis-ci.com/github/python-versioneer/python-versioneer
"""
+# pylint:disable=invalid-name,import-outside-toplevel,missing-function-docstring
+# pylint:disable=missing-class-docstring,too-many-branches,too-many-statements
+# pylint:disable=raise-missing-from,too-many-lines,too-many-locals,import-error
+# pylint:disable=too-few-public-methods,redefined-outer-name,consider-using-with
+# pylint:disable=attribute-defined-outside-init,too-many-arguments
import configparser
import errno
@@ -279,6 +286,8 @@
import re
import subprocess
import sys
+from typing import Callable, Dict
+import functools
class VersioneerConfig:
@@ -315,11 +324,11 @@ def get_root():
# module-import table will cache the first one. So we can't use
# os.path.dirname(__file__), as that will find whichever
# versioneer.py was first imported, even in later projects.
- me = os.path.realpath(os.path.abspath(__file__))
- me_dir = os.path.normcase(os.path.splitext(me)[0])
+ my_path = os.path.realpath(os.path.abspath(__file__))
+ me_dir = os.path.normcase(os.path.splitext(my_path)[0])
vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0])
if me_dir != vsr_dir:
- print("Warning: build in %s is using versioneer.py from %s" % (os.path.dirname(me), versioneer_py))
+ print("Warning: build in %s is using versioneer.py from %s" % (os.path.dirname(my_path), versioneer_py))
except NameError:
pass
return root
@@ -327,31 +336,29 @@ def get_root():
def get_config_from_root(root):
"""Read the project setup.cfg file to determine Versioneer config."""
- # This might raise EnvironmentError (if setup.cfg is missing), or
+ # This might raise OSError (if setup.cfg is missing), or
# configparser.NoSectionError (if it lacks a [versioneer] section), or
# configparser.NoOptionError (if it lacks "VCS="). See the docstring at
# the top of versioneer.py for instructions on writing your setup.cfg .
setup_cfg = os.path.join(root, "setup.cfg")
parser = configparser.ConfigParser()
- with open(setup_cfg, "r") as f:
- parser.read_file(f)
+ with open(setup_cfg, "r") as cfg_file:
+ parser.read_file(cfg_file)
VCS = parser.get("versioneer", "VCS") # mandatory
- def get(parser, name):
- if parser.has_option("versioneer", name):
- return parser.get("versioneer", name)
- return None
+ # Dict-like interface for non-mandatory entries
+ section = parser["versioneer"]
cfg = VersioneerConfig()
cfg.VCS = VCS
- cfg.style = get(parser, "style") or ""
- cfg.versionfile_source = get(parser, "versionfile_source")
- cfg.versionfile_build = get(parser, "versionfile_build")
- cfg.tag_prefix = get(parser, "tag_prefix")
- if cfg.tag_prefix in ("''", '""'):
+ cfg.style = section.get("style", "")
+ cfg.versionfile_source = section.get("versionfile_source")
+ cfg.versionfile_build = section.get("versionfile_build")
+ cfg.tag_prefix = section.get("tag_prefix")
+ if cfg.tag_prefix in ("''", '""', None):
cfg.tag_prefix = ""
- cfg.parentdir_prefix = get(parser, "parentdir_prefix")
- cfg.verbose = get(parser, "verbose")
+ cfg.parentdir_prefix = section.get("parentdir_prefix")
+ cfg.verbose = section.get("verbose")
return cfg
@@ -360,8 +367,8 @@ class NotThisMethod(Exception):
# these dictionaries contain VCS-specific tools
-LONG_VERSION_PY = {}
-HANDLERS = {}
+LONG_VERSION_PY: Dict[str, str] = {}
+HANDLERS: Dict[str, Dict[str, Callable]] = {}
def register_vcs_handler(vcs, method): # decorator
@@ -369,9 +376,7 @@ def register_vcs_handler(vcs, method): # decorator
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
- if vcs not in HANDLERS:
- HANDLERS[vcs] = {}
- HANDLERS[vcs][method] = f
+ HANDLERS.setdefault(vcs, {})[method] = f
return f
return decorate
@@ -380,16 +385,29 @@ def decorate(f):
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
- p = None
- for c in commands:
+ process = None
+
+ popen_kwargs = {}
+ if sys.platform == "win32":
+ # This hides the console window if pythonw.exe is used
+ startupinfo = subprocess.STARTUPINFO()
+ startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
+ popen_kwargs["startupinfo"] = startupinfo
+
+ for command in commands:
try:
- dispcmd = str([c] + args)
+ dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
- p = subprocess.Popen(
- [c] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None)
+ process = subprocess.Popen(
+ [command] + args,
+ cwd=cwd,
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=(subprocess.PIPE if hide_stderr else None),
+ **popen_kwargs,
)
break
- except EnvironmentError:
+ except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
@@ -401,13 +419,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
if verbose:
print("unable to find command, tried %s" % (commands,))
return None, None
- stdout = p.communicate()[0].strip().decode()
- if p.returncode != 0:
+ stdout = process.communicate()[0].strip().decode()
+ if process.returncode != 0:
if verbose:
print("unable to run %s (error)" % dispcmd)
print("stdout was %s" % stdout)
- return None, p.returncode
- return stdout, p.returncode
+ return None, process.returncode
+ return stdout, process.returncode
LONG_VERSION_PY[
@@ -420,7 +438,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
# that just contains the computed version number.
# This file is released into the public domain. Generated by
-# versioneer-0.19 (https://github.com/python-versioneer/python-versioneer)
+# versioneer-0.23 (https://github.com/python-versioneer/python-versioneer)
"""Git implementation of _version.py."""
@@ -429,6 +447,8 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
import re
import subprocess
import sys
+from typing import Callable, Dict
+import functools
def get_keywords():
@@ -466,8 +486,8 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
-LONG_VERSION_PY = {}
-HANDLERS = {}
+LONG_VERSION_PY: Dict[str, str] = {}
+HANDLERS: Dict[str, Dict[str, Callable]] = {}
def register_vcs_handler(vcs, method): # decorator
@@ -485,17 +505,25 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
- p = None
- for c in commands:
+ process = None
+
+ popen_kwargs = {}
+ if sys.platform == "win32":
+ # This hides the console window if pythonw.exe is used
+ startupinfo = subprocess.STARTUPINFO()
+ startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
+ popen_kwargs["startupinfo"] = startupinfo
+
+ for command in commands:
try:
- dispcmd = str([c] + args)
+ dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
- p = subprocess.Popen([c] + args, cwd=cwd, env=env,
- stdout=subprocess.PIPE,
- stderr=(subprocess.PIPE if hide_stderr
- else None))
+ process = subprocess.Popen([command] + args, cwd=cwd, env=env,
+ stdout=subprocess.PIPE,
+ stderr=(subprocess.PIPE if hide_stderr
+ else None), **popen_kwargs)
break
- except EnvironmentError:
+ except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
@@ -507,13 +535,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
if verbose:
print("unable to find command, tried %%s" %% (commands,))
return None, None
- stdout = p.communicate()[0].strip().decode()
- if p.returncode != 0:
+ stdout = process.communicate()[0].strip().decode()
+ if process.returncode != 0:
if verbose:
print("unable to run %%s (error)" %% dispcmd)
print("stdout was %%s" %% stdout)
- return None, p.returncode
- return stdout, p.returncode
+ return None, process.returncode
+ return stdout, process.returncode
def versions_from_parentdir(parentdir_prefix, root, verbose):
@@ -525,15 +553,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
"""
rootdirs = []
- for i in range(3):
+ for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None,
"dirty": False, "error": None, "date": None}
- else:
- rootdirs.append(root)
- root = os.path.dirname(root) # up a level
+ rootdirs.append(root)
+ root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %%s but none started with prefix %%s" %%
@@ -550,22 +577,21 @@ def git_get_keywords(versionfile_abs):
# _version.py.
keywords = {}
try:
- f = open(versionfile_abs, "r")
- for line in f.readlines():
- if line.strip().startswith("git_refnames ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["refnames"] = mo.group(1)
- if line.strip().startswith("git_full ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["full"] = mo.group(1)
- if line.strip().startswith("git_date ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["date"] = mo.group(1)
- f.close()
- except EnvironmentError:
+ with open(versionfile_abs, "r") as fobj:
+ for line in fobj:
+ if line.strip().startswith("git_refnames ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["refnames"] = mo.group(1)
+ if line.strip().startswith("git_full ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["full"] = mo.group(1)
+ if line.strip().startswith("git_date ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["date"] = mo.group(1)
+ except OSError:
pass
return keywords
@@ -573,8 +599,8 @@ def git_get_keywords(versionfile_abs):
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
"""Get version information from git keywords."""
- if not keywords:
- raise NotThisMethod("no keywords at all, weird")
+ if "refnames" not in keywords:
+ raise NotThisMethod("Short version file found")
date = keywords.get("date")
if date is not None:
# Use only the last line. Previous lines may contain GPG signature
@@ -593,11 +619,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
- refs = set([r.strip() for r in refnames.strip("()").split(",")])
+ refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
- tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
+ tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %%d
@@ -606,7 +632,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
- tags = set([r for r in refs if re.search(r'\d', r)])
+ tags = {r for r in refs if re.search(r'\d', r)}
if verbose:
print("discarding '%%s', no digits" %% ",".join(refs - tags))
if verbose:
@@ -615,6 +641,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):]
+ # Filter out refs that exactly match prefix or that don't start
+ # with a number once the prefix is stripped (mostly a concern
+ # when prefix is '')
+ if not re.match(r'\d', r):
+ continue
if verbose:
print("picking %%s" %% r)
return {"version": r,
@@ -630,7 +661,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
@register_vcs_handler("git", "pieces_from_vcs")
-def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
+def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
"""Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not*
@@ -641,8 +672,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
- out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
- hide_stderr=True)
+ # GIT_DIR can interfere with correct operation of Versioneer.
+ # It may be intended to be passed to the Versioneer-versioned project,
+ # but that should not change where we get our version from.
+ env = os.environ.copy()
+ env.pop("GIT_DIR", None)
+ runner = functools.partial(runner, env=env)
+
+ _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
+ hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %%s not under git control" %% root)
@@ -650,15 +688,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
- "--always", "--long",
- "--match", "%%s*" %% tag_prefix],
- cwd=root)
+ describe_out, rc = runner(GITS, [
+ "describe", "--tags", "--dirty", "--always", "--long",
+ "--match", f"{tag_prefix}[[:digit:]]*"
+ ], cwd=root)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip()
- full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
+ full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None:
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()
@@ -668,6 +706,39 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
+ branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
+ cwd=root)
+ # --abbrev-ref was added in git-1.6.3
+ if rc != 0 or branch_name is None:
+ raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
+ branch_name = branch_name.strip()
+
+ if branch_name == "HEAD":
+ # If we aren't exactly on a branch, pick a branch which represents
+ # the current commit. If all else fails, we are on a branchless
+ # commit.
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
+ # --contains was added in git-1.5.4
+ if rc != 0 or branches is None:
+ raise NotThisMethod("'git branch --contains' returned error")
+ branches = branches.split("\n")
+
+ # Remove the first line if we're running detached
+ if "(" in branches[0]:
+ branches.pop(0)
+
+ # Strip off the leading "* " from the list of branches.
+ branches = [branch[2:] for branch in branches]
+ if "master" in branches:
+ branch_name = "master"
+ elif not branches:
+ branch_name = None
+ else:
+ # Pick the first branch that is returned. Good or bad.
+ branch_name = branches[0]
+
+ pieces["branch"] = branch_name
+
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens.
git_describe = describe_out
@@ -684,7 +755,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
if not mo:
- # unparseable. Maybe git-describe is misbehaving?
+ # unparsable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%%s'"
%% describe_out)
return pieces
@@ -709,13 +780,11 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else:
# HEX: no tags
pieces["closest-tag"] = None
- count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
- cwd=root)
- pieces["distance"] = int(count_out) # total number of commits
+ out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
+ pieces["distance"] = len(out.split()) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
- date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"],
- cwd=root)[0].strip()
+ date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
@@ -756,16 +825,64 @@ def render_pep440(pieces):
return rendered
+def render_pep440_branch(pieces):
+ """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
+
+ The ".dev0" means not master branch. Note that .dev0 sorts backwards
+ (a feature branch will appear "older" than the master branch).
+
+ Exceptions:
+ 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0"
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += "+untagged.%%d.g%%s" %% (pieces["distance"],
+ pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
+def pep440_split_post(ver):
+ """Split pep440 version string at the post-release segment.
+
+ Returns the release segments before the post-release and the
+ post-release version number (or -1 if no post-release segment is present).
+ """
+ vc = str.split(ver, ".post")
+ return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
+
+
def render_pep440_pre(pieces):
- """TAG[.post0.devDISTANCE] -- No -dirty.
+ """TAG[.postN.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
if pieces["distance"]:
- rendered += ".post0.dev%%d" %% pieces["distance"]
+ # update the post release segment
+ tag_version, post_version = pep440_split_post(pieces["closest-tag"])
+ rendered = tag_version
+ if post_version is not None:
+ rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"])
+ else:
+ rendered += ".post0.dev%%d" %% (pieces["distance"])
+ else:
+ # no commits, use the tag as the version
+ rendered = pieces["closest-tag"]
else:
# exception #1
rendered = "0.post0.dev%%d" %% pieces["distance"]
@@ -799,6 +916,35 @@ def render_pep440_post(pieces):
return rendered
+def render_pep440_post_branch(pieces):
+ """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
+
+ The ".dev0" means not master branch.
+
+ Exceptions:
+ 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += ".post%%d" %% pieces["distance"]
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "g%%s" %% pieces["short"]
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0.post%%d" %% pieces["distance"]
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += "+g%%s" %% pieces["short"]
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
def render_pep440_old(pieces):
"""TAG[.postDISTANCE[.dev0]] .
@@ -875,10 +1021,14 @@ def render(pieces, style):
if style == "pep440":
rendered = render_pep440(pieces)
+ elif style == "pep440-branch":
+ rendered = render_pep440_branch(pieces)
elif style == "pep440-pre":
rendered = render_pep440_pre(pieces)
elif style == "pep440-post":
rendered = render_pep440_post(pieces)
+ elif style == "pep440-post-branch":
+ rendered = render_pep440_post_branch(pieces)
elif style == "pep440-old":
rendered = render_pep440_old(pieces)
elif style == "git-describe":
@@ -914,7 +1064,7 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
- for i in cfg.versionfile_source.split('/'):
+ for _ in cfg.versionfile_source.split('/'):
root = os.path.dirname(root)
except NameError:
return {"version": "0+unknown", "full-revisionid": None,
@@ -949,22 +1099,21 @@ def git_get_keywords(versionfile_abs):
# _version.py.
keywords = {}
try:
- f = open(versionfile_abs, "r")
- for line in f.readlines():
- if line.strip().startswith("git_refnames ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["refnames"] = mo.group(1)
- if line.strip().startswith("git_full ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["full"] = mo.group(1)
- if line.strip().startswith("git_date ="):
- mo = re.search(r'=\s*"(.*)"', line)
- if mo:
- keywords["date"] = mo.group(1)
- f.close()
- except EnvironmentError:
+ with open(versionfile_abs, "r") as fobj:
+ for line in fobj:
+ if line.strip().startswith("git_refnames ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["refnames"] = mo.group(1)
+ if line.strip().startswith("git_full ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["full"] = mo.group(1)
+ if line.strip().startswith("git_date ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["date"] = mo.group(1)
+ except OSError:
pass
return keywords
@@ -972,8 +1121,8 @@ def git_get_keywords(versionfile_abs):
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
"""Get version information from git keywords."""
- if not keywords:
- raise NotThisMethod("no keywords at all, weird")
+ if "refnames" not in keywords:
+ raise NotThisMethod("Short version file found")
date = keywords.get("date")
if date is not None:
# Use only the last line. Previous lines may contain GPG signature
@@ -992,11 +1141,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
- refs = set([r.strip() for r in refnames.strip("()").split(",")])
+ refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
- tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
+ tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
@@ -1005,7 +1154,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
- tags = set([r for r in refs if re.search(r"\d", r)])
+ tags = {r for r in refs if re.search(r"\d", r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
@@ -1014,6 +1163,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix) :]
+ # Filter out refs that exactly match prefix or that don't start
+ # with a number once the prefix is stripped (mostly a concern
+ # when prefix is '')
+ if not re.match(r"\d", r):
+ continue
if verbose:
print("picking %s" % r)
return {
@@ -1036,7 +1190,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
@register_vcs_handler("git", "pieces_from_vcs")
-def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
+def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
"""Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not*
@@ -1047,7 +1201,14 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
- out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
+ # GIT_DIR can interfere with correct operation of Versioneer.
+ # It may be intended to be passed to the Versioneer-versioned project,
+ # but that should not change where we get our version from.
+ env = os.environ.copy()
+ env.pop("GIT_DIR", None)
+ runner = functools.partial(runner, env=env)
+
+ _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
@@ -1055,14 +1216,14 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
- describe_out, rc = run_command(
- GITS, ["describe", "--tags", "--dirty", "--always", "--long", "--match", "%s*" % tag_prefix], cwd=root
+ describe_out, rc = runner(
+ GITS, ["describe", "--tags", "--dirty", "--always", "--long", "--match", f"{tag_prefix}[[:digit:]]*"], cwd=root
)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip()
- full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
+ full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None:
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()
@@ -1072,6 +1233,38 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
+ branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root)
+ # --abbrev-ref was added in git-1.6.3
+ if rc != 0 or branch_name is None:
+ raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
+ branch_name = branch_name.strip()
+
+ if branch_name == "HEAD":
+ # If we aren't exactly on a branch, pick a branch which represents
+ # the current commit. If all else fails, we are on a branchless
+ # commit.
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
+ # --contains was added in git-1.5.4
+ if rc != 0 or branches is None:
+ raise NotThisMethod("'git branch --contains' returned error")
+ branches = branches.split("\n")
+
+ # Remove the first line if we're running detached
+ if "(" in branches[0]:
+ branches.pop(0)
+
+ # Strip off the leading "* " from the list of branches.
+ branches = [branch[2:] for branch in branches]
+ if "master" in branches:
+ branch_name = "master"
+ elif not branches:
+ branch_name = None
+ else:
+ # Pick the first branch that is returned. Good or bad.
+ branch_name = branches[0]
+
+ pieces["branch"] = branch_name
+
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens.
git_describe = describe_out
@@ -1088,7 +1281,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# TAG-NUM-gHEX
mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo:
- # unparseable. Maybe git-describe is misbehaving?
+ # unparsable. Maybe git-describe is misbehaving?
pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
return pieces
@@ -1111,11 +1304,11 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else:
# HEX: no tags
pieces["closest-tag"] = None
- count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
- pieces["distance"] = int(count_out) # total number of commits
+ out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
+ pieces["distance"] = len(out.split()) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
- date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
+ date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
@@ -1124,7 +1317,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
return pieces
-def do_vcs_install(manifest_in, versionfile_source, ipy):
+def do_vcs_install(versionfile_source, ipy):
"""Git-specific installation logic for Versioneer.
For Git, this means creating/changing .gitattributes to mark _version.py
@@ -1133,31 +1326,30 @@ def do_vcs_install(manifest_in, versionfile_source, ipy):
GITS = ["git"]
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
- files = [manifest_in, versionfile_source]
+ files = [versionfile_source]
if ipy:
files.append(ipy)
try:
- me = __file__
- if me.endswith(".pyc") or me.endswith(".pyo"):
- me = os.path.splitext(me)[0] + ".py"
- versioneer_file = os.path.relpath(me)
+ my_path = __file__
+ if my_path.endswith(".pyc") or my_path.endswith(".pyo"):
+ my_path = os.path.splitext(my_path)[0] + ".py"
+ versioneer_file = os.path.relpath(my_path)
except NameError:
versioneer_file = "versioneer.py"
files.append(versioneer_file)
present = False
try:
- f = open(".gitattributes", "r")
- for line in f.readlines():
- if line.strip().startswith(versionfile_source):
- if "export-subst" in line.strip().split()[1:]:
- present = True
- f.close()
- except EnvironmentError:
+ with open(".gitattributes", "r") as fobj:
+ for line in fobj:
+ if line.strip().startswith(versionfile_source):
+ if "export-subst" in line.strip().split()[1:]:
+ present = True
+ break
+ except OSError:
pass
if not present:
- f = open(".gitattributes", "a+")
- f.write("%s export-subst\n" % versionfile_source)
- f.close()
+ with open(".gitattributes", "a+") as fobj:
+ fobj.write(f"{versionfile_source} export-subst\n")
files.append(".gitattributes")
run_command(GITS, ["add", "--"] + files)
@@ -1171,7 +1363,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
"""
rootdirs = []
- for i in range(3):
+ for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {
@@ -1181,9 +1373,8 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
"error": None,
"date": None,
}
- else:
- rootdirs.append(root)
- root = os.path.dirname(root) # up a level
+ rootdirs.append(root)
+ root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %s but none started with prefix %s" % (str(rootdirs), parentdir_prefix))
@@ -1191,7 +1382,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
SHORT_VERSION_PY = """
-# This file was generated by 'versioneer.py' (0.19) from
+# This file was generated by 'versioneer.py' (0.23) from
# revision-control system data, or from the parent directory name of an
# unpacked source archive. Distribution tarballs contain a pre-generated copy
# of this file.
@@ -1213,7 +1404,7 @@ def versions_from_file(filename):
try:
with open(filename) as f:
contents = f.read()
- except EnvironmentError:
+ except OSError:
raise NotThisMethod("unable to read _version.py")
mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S)
if not mo:
@@ -1264,16 +1455,63 @@ def render_pep440(pieces):
return rendered
+def render_pep440_branch(pieces):
+ """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
+
+ The ".dev0" means not master branch. Note that .dev0 sorts backwards
+ (a feature branch will appear "older" than the master branch).
+
+ Exceptions:
+ 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0"
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
+def pep440_split_post(ver):
+ """Split pep440 version string at the post-release segment.
+
+ Returns the release segments before the post-release and the
+ post-release version number (or -1 if no post-release segment is present).
+ """
+ vc = str.split(ver, ".post")
+ return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
+
+
def render_pep440_pre(pieces):
- """TAG[.post0.devDISTANCE] -- No -dirty.
+ """TAG[.postN.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
- rendered = pieces["closest-tag"]
if pieces["distance"]:
- rendered += ".post0.dev%d" % pieces["distance"]
+ # update the post release segment
+ tag_version, post_version = pep440_split_post(pieces["closest-tag"])
+ rendered = tag_version
+ if post_version is not None:
+ rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
+ else:
+ rendered += ".post0.dev%d" % (pieces["distance"])
+ else:
+ # no commits, use the tag as the version
+ rendered = pieces["closest-tag"]
else:
# exception #1
rendered = "0.post0.dev%d" % pieces["distance"]
@@ -1307,6 +1545,35 @@ def render_pep440_post(pieces):
return rendered
+def render_pep440_post_branch(pieces):
+ """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
+
+ The ".dev0" means not master branch.
+
+ Exceptions:
+ 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += ".post%d" % pieces["distance"]
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "g%s" % pieces["short"]
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0.post%d" % pieces["distance"]
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += "+g%s" % pieces["short"]
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
def render_pep440_old(pieces):
"""TAG[.postDISTANCE[.dev0]] .
@@ -1385,10 +1652,14 @@ def render(pieces, style):
if style == "pep440":
rendered = render_pep440(pieces)
+ elif style == "pep440-branch":
+ rendered = render_pep440_branch(pieces)
elif style == "pep440-pre":
rendered = render_pep440_pre(pieces)
elif style == "pep440-post":
rendered = render_pep440_post(pieces)
+ elif style == "pep440-post-branch":
+ rendered = render_pep440_post_branch(pieces)
elif style == "pep440-old":
rendered = render_pep440_old(pieces)
elif style == "git-describe":
@@ -1496,7 +1767,7 @@ def get_version():
def get_cmdclass(cmdclass=None):
- """Get the custom setuptools/distutils subclasses used by Versioneer.
+ """Get the custom setuptools subclasses used by Versioneer.
If the package uses a different cmdclass (e.g. one from numpy), it
should be provide as an argument.
@@ -1518,8 +1789,8 @@ def get_cmdclass(cmdclass=None):
cmds = {} if cmdclass is None else cmdclass.copy()
- # we add "version" to both distutils and setuptools
- from distutils.core import Command
+ # we add "version" to setuptools
+ from setuptools import Command
class cmd_version(Command):
description = "report generated version string"
@@ -1543,7 +1814,7 @@ def run(self):
cmds["version"] = cmd_version
- # we override "build_py" in both distutils and setuptools
+ # we override "build_py" in setuptools
#
# most invocation pathways end up running build_py:
# distutils/build -> build_py
@@ -1558,13 +1829,14 @@ def run(self):
# then does setup.py bdist_wheel, or sometimes setup.py install
# setup.py egg_info -> ?
+ # pip install -e . and setuptool/editable_wheel will invoke build_py
+ # but the build_py command is not expected to copy any files.
+
# we override different "build_py" commands for both environments
if "build_py" in cmds:
_build_py = cmds["build_py"]
- elif "setuptools" in sys.modules:
- from setuptools.command.build_py import build_py as _build_py
else:
- from distutils.command.build_py import build_py as _build_py
+ from setuptools.command.build_py import build_py as _build_py
class cmd_build_py(_build_py):
def run(self):
@@ -1572,6 +1844,10 @@ def run(self):
cfg = get_config_from_root(root)
versions = get_versions()
_build_py.run(self)
+ if getattr(self, "editable_mode", False):
+ # During editable installs `.py` and data files are
+ # not copied to build_lib
+ return
# now locate _version.py in the new build/ directory and replace
# it with an updated value
if cfg.versionfile_build:
@@ -1581,10 +1857,10 @@ def run(self):
cmds["build_py"] = cmd_build_py
- if "setuptools" in sys.modules:
- from setuptools.command.build_ext import build_ext as _build_ext
+ if "build_ext" in cmds:
+ _build_ext = cmds["build_ext"]
else:
- from distutils.command.build_ext import build_ext as _build_ext
+ from setuptools.command.build_ext import build_ext as _build_ext
class cmd_build_ext(_build_ext):
def run(self):
@@ -1600,7 +1876,14 @@ def run(self):
return
# now locate _version.py in the new build/ directory and replace
# it with an updated value
- target_versionfile = os.path.join(self.build_lib, cfg.versionfile_source)
+ target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build)
+ if not os.path.exists(target_versionfile):
+ print(
+ f"Warning: {target_versionfile} does not exist, skipping "
+ "version update. This can happen if you are running build_ext "
+ "without first running build_py."
+ )
+ return
print("UPDATING %s" % target_versionfile)
write_to_version_file(target_versionfile, versions)
@@ -1672,13 +1955,48 @@ def run(self):
cmds["py2exe"] = cmd_py2exe
+ # sdist farms its file list building out to egg_info
+ if "egg_info" in cmds:
+ _sdist = cmds["egg_info"]
+ else:
+ from setuptools.command.egg_info import egg_info as _egg_info
+
+ class cmd_egg_info(_egg_info):
+ def find_sources(self):
+ # egg_info.find_sources builds the manifest list and writes it
+ # in one shot
+ super().find_sources()
+
+ # Modify the filelist and normalize it
+ root = get_root()
+ cfg = get_config_from_root(root)
+ self.filelist.append("versioneer.py")
+ if cfg.versionfile_source:
+ # There are rare cases where versionfile_source might not be
+ # included by default, so we must be explicit
+ self.filelist.append(cfg.versionfile_source)
+ self.filelist.sort()
+ self.filelist.remove_duplicates()
+
+ # The write method is hidden in the manifest_maker instance that
+ # generated the filelist and was thrown away
+ # We will instead replicate their final normalization (to unicode,
+ # and POSIX-style paths)
+ from setuptools import unicode_utils
+
+ normalized = [unicode_utils.filesys_decode(f).replace(os.sep, "/") for f in self.filelist.files]
+
+ manifest_filename = os.path.join(self.egg_info, "SOURCES.txt")
+ with open(manifest_filename, "w") as fobj:
+ fobj.write("\n".join(normalized))
+
+ cmds["egg_info"] = cmd_egg_info
+
# we override different "sdist" commands for both environments
if "sdist" in cmds:
_sdist = cmds["sdist"]
- elif "setuptools" in sys.modules:
- from setuptools.command.sdist import sdist as _sdist
else:
- from distutils.command.sdist import sdist as _sdist
+ from setuptools.command.sdist import sdist as _sdist
class cmd_sdist(_sdist):
def run(self):
@@ -1742,20 +2060,25 @@ def make_release_tree(self, base_dir, files):
"""
-INIT_PY_SNIPPET = """
+OLD_SNIPPET = """
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
"""
+INIT_PY_SNIPPET = """
+from . import {0}
+__version__ = {0}.get_versions()['version']
+"""
+
def do_setup():
"""Do main VCS-independent setup function for installing Versioneer."""
root = get_root()
try:
cfg = get_config_from_root(root)
- except (EnvironmentError, configparser.NoSectionError, configparser.NoOptionError) as e:
- if isinstance(e, (EnvironmentError, configparser.NoSectionError)):
+ except (OSError, configparser.NoSectionError, configparser.NoOptionError) as e:
+ if isinstance(e, (OSError, configparser.NoSectionError)):
print("Adding sample versioneer config to setup.cfg", file=sys.stderr)
with open(os.path.join(root, "setup.cfg"), "a") as f:
f.write(SAMPLE_CONFIG)
@@ -1781,53 +2104,28 @@ def do_setup():
try:
with open(ipy, "r") as f:
old = f.read()
- except EnvironmentError:
+ except OSError:
old = ""
- if INIT_PY_SNIPPET not in old:
+ module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0]
+ snippet = INIT_PY_SNIPPET.format(module)
+ if OLD_SNIPPET in old:
+ print(" replacing boilerplate in %s" % ipy)
+ with open(ipy, "w") as f:
+ f.write(old.replace(OLD_SNIPPET, snippet))
+ elif snippet not in old:
print(" appending to %s" % ipy)
with open(ipy, "a") as f:
- f.write(INIT_PY_SNIPPET)
+ f.write(snippet)
else:
print(" %s unmodified" % ipy)
else:
print(" %s doesn't exist, ok" % ipy)
ipy = None
- # Make sure both the top-level "versioneer.py" and versionfile_source
- # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so
- # they'll be copied into source distributions. Pip won't be able to
- # install the package without this.
- manifest_in = os.path.join(root, "MANIFEST.in")
- simple_includes = set()
- try:
- with open(manifest_in, "r") as f:
- for line in f:
- if line.startswith("include "):
- for include in line.split()[1:]:
- simple_includes.add(include)
- except EnvironmentError:
- pass
- # That doesn't cover everything MANIFEST.in can do
- # (http://docs.python.org/2/distutils/sourcedist.html#commands), so
- # it might give some false negatives. Appending redundant 'include'
- # lines is safe, though.
- if "versioneer.py" not in simple_includes:
- print(" appending 'versioneer.py' to MANIFEST.in")
- with open(manifest_in, "a") as f:
- f.write("include versioneer.py\n")
- else:
- print(" 'versioneer.py' already in MANIFEST.in")
- if cfg.versionfile_source not in simple_includes:
- print(" appending versionfile_source ('%s') to MANIFEST.in" % cfg.versionfile_source)
- with open(manifest_in, "a") as f:
- f.write("include %s\n" % cfg.versionfile_source)
- else:
- print(" versionfile_source already in MANIFEST.in")
-
# Make VCS-specific changes. For git, this means creating/changing
# .gitattributes to mark _version.py for export-subst keyword
# substitution.
- do_vcs_install(manifest_in, cfg.versionfile_source, ipy)
+ do_vcs_install(cfg.versionfile_source, ipy)
return 0