This repo contains the supporting PyTorch code and configuration files to reproduce the results of the article "Co-Manifold Learning for Semi-supervised Medical Image Segmentation".
In this study, we investigate jointly learning hyperbolic and Euclidean space representations and enforcing consistency between them for semi-supervised medical image segmentation. We argue that, for complex volumetric medical data, hyperbolic spaces are beneficial for modeling the inductive biases of the data. We propose an approach that combines the two geometries: a variational encoder-decoder model with a hyperbolic probabilistic latent space is co-trained with a separate variational encoder-decoder model with a Euclidean probabilistic latent space, so that the two learn complementary representations. This bridges the gap of co-training across manifolds (Co-Manifold learning) in a principled manner. To capture complementary information and hierarchical relationships, we propose a latent space embedding loss that maximizes the disagreement between embeddings across manifolds. Additionally, we employ adversarial learning to enhance segmentation performance: the network in hyperbolic latent space is guided by the confident regions identified by the network in Euclidean space. Conversely, the network in Euclidean space is informed by hyperbolic uncertainty, creating a dual uncertainty-aware framework in which the two spaces collaboratively learn confident regions from each other. Our proposed method achieves competitive results on two benchmarks for semi-supervised medical image segmentation.
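To make the idea of a hyperbolic latent space concrete, here is a minimal sketch of two standard operations on the Poincaré ball. It assumes the Poincaré ball model with curvature -c, which the description above does not specify, and the function names `expmap0` and `poincare_distance` are illustrative and not taken from this repository.

```python
import math


def expmap0(v, c=1.0):
    """Map a tangent vector at the origin onto the Poincare ball of curvature -c."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        return [0.0 for _ in v]
    sqrt_c = math.sqrt(c)
    # tanh keeps the result strictly inside the ball of radius 1 / sqrt(c).
    scale = math.tanh(sqrt_c * norm) / (sqrt_c * norm)
    return [scale * x for x in v]


def poincare_distance(x, y, c=1.0):
    """Geodesic distance between two points inside the Poincare ball."""
    def sq_norm(u):
        return sum(a * a for a in u)

    diff = sq_norm([a - b for a, b in zip(x, y)])
    denom = (1.0 - c * sq_norm(x)) * (1.0 - c * sq_norm(y))
    # Clamp to 1.0 to guard acosh against rounding just below its domain.
    arg = max(1.0, 1.0 + 2.0 * c * diff / denom)
    return math.acosh(arg) / math.sqrt(c)


if __name__ == "__main__":
    z = expmap0([0.3, 0.4])  # hypothetical latent code
    print(z, poincare_distance([0.0, 0.0], z))
```

Distances grow rapidly near the boundary of the ball, which is the property usually cited when hyperbolic embeddings are used for hierarchical structure.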
To be Added
This section provides details on the environment setup and dependencies required to train and test the Co-Manifold model.
This software was originally designed and run on Ubuntu (it is compatible with Windows 11 as well).
All the experiments were conducted on Ubuntu 20.04 (Focal) with Python 3.8.
To train Co-Manifold with the given settings, the system requires a GPU with at least 40 GB of memory. All the experiments were conducted on a single NVIDIA A40 GPU.
(No non-standard hardware is required.)
pip install virtualenv
virtualenv -p /usr/bin/python3.8 venv
source venv/bin/activate
- Install torch:
pip3 install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
- Install other dependencies:
pip install -r requirements.txt
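After installation, a quick check along the following lines can confirm that the interpreter, torch build, and GPU match the setup described above. This is a minimal sketch rather than part of the repository: the helper name `check_environment` is hypothetical, and it only reports values without enforcing the 40 GB requirement.

```python
import importlib
import sys


def check_environment():
    """Report Python version, torch version, CUDA availability and GPU memory."""
    report = {
        "python": "%d.%d" % sys.version_info[:2],
        "torch": None,
        "cuda": False,
        "gpu_gb": None,
    }
    try:
        torch = importlib.import_module("torch")
    except (ImportError, OSError):
        # torch is missing or its native libraries failed to load.
        return report
    report["torch"] = torch.__version__
    report["cuda"] = bool(torch.cuda.is_available())
    if report["cuda"]:
        props = torch.cuda.get_device_properties(0)
        # total_memory is in bytes; convert to GiB for readability.
        report["gpu_gb"] = round(props.total_memory / 1024 ** 3, 1)
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

With the install commands above, the expected torch version string would start with `1.10.2`.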
The experiments are conducted on two publicly available datasets:
- 2018 Left Atrial Segmentation Challenge Dataset : http://atriaseg2018.cardiacatlas.org
- MSD BraTS Dataset : http://medicaldecathlon.com/
Pre-processed data can be found in the data folder.
Download the trained model weights from this shared drive link and put them under the folder code_la/model_weights or code_brats/model_weights.
- To train the model on the LA MRI dataset with 10% labeled data
cd code_la
CUDA_VISIBLE_DEVICES=0 nohup python train.py --dataset_name "LA" --labelnum 8 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.2 --t_m 0.1 --hidden-dim 256 --batch_size 4 --labeled_bs 2 &> la_10.out &
- To train the model on the LA MRI dataset with 20% labeled data
cd code_la
CUDA_VISIBLE_DEVICES=0 nohup python train.py --dataset_name "LA" --max_iteration 30000 --labelnum 16 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.1 --t_m 0.1 --hidden-dim 256 --batch_size 4 --labeled_bs 2 &> la_20.out &
- To train the model on the MSD BraTS MRI dataset with 10% labeled data
cd code_brats
nohup python train.py --dataset_name MSD_BRATS --labelnum 39 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.02 --t_m 0.2 --batch_size 6 --labeled_bs 3 &> msd_10_perc.out &
- To train the model on the MSD BraTS MRI dataset with 20% labeled data
cd code_brats
nohup python train.py --dataset_name MSD_BRATS --labelnum 77 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.02 --t_m 0.2 --batch_size 6 --labeled_bs 3 &> msd_20_perc.out &
- To test the Co-Manifold ensemble model on the LA MRI dataset with 10% labeled data
cd code_la
CUDA_VISIBLE_DEVICES=0 nohup python inference.py --dataset_name "LA" --labelnum 8 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.2 --t_m 0.1 --hidden-dim 256 --batch_size 4 --labeled_bs 2 &> la_10_eval.out &
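The commands above pass `--labelnum`, `--batch_size` and `--labeled_bs` together. The sketch below illustrates one common way such flags are interpreted in MC-Net-style semi-supervised code: the first `labelnum` samples are treated as labeled, and each batch mixes `labeled_bs` labeled samples with `batch_size - labeled_bs` unlabeled ones. This is an assumption about the convention, not a copy of this repository's sampler; the function name `two_stream_batches` and the training-set size of 80 in the usage line are hypothetical.

```python
import itertools
import random


def two_stream_batches(num_samples, labelnum, batch_size, labeled_bs, seed=0):
    """Yield index batches: labeled_bs labeled indices, then unlabeled ones."""
    unlabeled_bs = batch_size - labeled_bs
    if not 0 < labeled_bs <= batch_size:
        raise ValueError("labeled_bs must be in (0, batch_size]")
    if labelnum > num_samples or (unlabeled_bs > 0 and labelnum == num_samples):
        raise ValueError("not enough unlabeled samples")

    rng = random.Random(seed)
    labeled = list(range(labelnum))                 # assumed: first labelnum are labeled
    unlabeled = list(range(labelnum, num_samples))  # the rest are unlabeled

    def endless(pool):
        # Reshuffle and restart whenever the pool is exhausted.
        while True:
            perm = pool[:]
            rng.shuffle(perm)
            yield from perm

    rng.shuffle(labeled)
    unlabeled_stream = endless(unlabeled)
    # One epoch is one pass over the labeled indices.
    for start in range(0, labelnum - labeled_bs + 1, labeled_bs):
        lab = labeled[start:start + labeled_bs]
        unl = list(itertools.islice(unlabeled_stream, unlabeled_bs))
        yield lab + unl


if __name__ == "__main__":
    # Hypothetical LA-like setting: 80 training scans, 8 of them labeled.
    for batch in two_stream_batches(80, labelnum=8, batch_size=4, labeled_bs=2):
        print(batch)
```

Under this convention, `--batch_size 4 --labeled_bs 2` means every training step sees two labeled and two unlabeled volumes.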
This repository makes liberal use of code from capturing-implicit-hierarchical-structure and MC-Net.
If you find this repository useful, please consider giving us a star ⭐