
Official PyTorch implementation for Co-Manifold Learning for Semi-supervised Medical Image Segmentation


himashi92/Co-Manifold


Co-Manifold

This repository contains the official PyTorch code and configuration files to reproduce the results of the article "Co-Manifold Learning for Semi-supervised Medical Image Segmentation".

Abstract

In this study, we investigate jointly learning hyperbolic and Euclidean space representations and enforcing consistency between them for semi-supervised medical image segmentation. We argue that for complex medical volumetric data, hyperbolic spaces are beneficial for modeling the inductive biases of the data. We propose an approach that incorporates the two geometries to co-train a variational encoder-decoder model with a hyperbolic probabilistic latent space and a separate variational encoder-decoder model with a Euclidean probabilistic latent space, yielding complementary representations and thereby bridging the gap of co-training across manifolds (Co-Manifold learning) in a principled manner. To capture complementary information and hierarchical relationships, we propose a latent space embedding loss aimed at maximizing disagreement between embeddings across manifolds. Additionally, we employ adversarial learning to enhance segmentation performance by guiding the network in hyperbolic latent space using confident regions identified by the network in Euclidean space. Conversely, the network in Euclidean space is informed by hyperbolic uncertainty, creating a dual uncertainty-aware framework that enables the two spaces to collaboratively learn confident regions from each other. Our proposed method achieves competitive results on two benchmarks for semi-supervised medical image segmentation.

Link to full paper:

To be Added

Proposed Architecture

[Figure: Proposed Co-Manifold architecture]
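The abstract describes each network guiding the other through confident regions and uncertainty. As a minimal sketch of the "identify confident regions" step, the code below assumes voxel-wise predictive entropy as the uncertainty measure and uses a plain masked squared-error consistency term in place of the paper's adversarial guidance. Both choices are assumptions for illustration (the paper may use a different uncertainty estimate), and all names are hypothetical.

```python
import math
from typing import List, Sequence

# Hypothetical helpers for illustration; not the Co-Manifold repository's API.
# Assumption: "confident" means low entropy of the per-voxel softmax output.
# The masked MSE below is a simple stand-in, NOT the paper's adversarial loss.


def voxel_entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (in nats) of one voxel's class-probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def confidence_mask(prob_map: Sequence[Sequence[float]], threshold: float) -> List[bool]:
    """True for voxels whose entropy is below `threshold` (confident predictions)."""
    return [voxel_entropy(p) < threshold for p in prob_map]


def masked_consistency(student_fg: Sequence[float], teacher_fg: Sequence[float],
                       mask: Sequence[bool]) -> float:
    """Mean squared difference of foreground probabilities on confident voxels only."""
    diffs = [(s - t) ** 2 for s, t, m in zip(student_fg, teacher_fg, mask) if m]
    return sum(diffs) / len(diffs) if diffs else 0.0
```

With a two-class map such as `[[0.98, 0.02], [0.5, 0.5], [0.9, 0.1]]` and a threshold of 0.2, only the first voxel (entropy about 0.098 nats) counts as confident; the 50/50 voxel has the maximum two-class entropy, ln 2 ≈ 0.693.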

System requirements

This section details the environment setup and dependencies required to train/test the Co-Manifold model. The software was originally developed and run on Ubuntu, and it is also compatible with Windows 11.
All experiments were conducted on Ubuntu 20.04 (Focal Fossa) with Python 3.8.
Training Co-Manifold with the given settings requires a GPU with at least 40 GB of memory. All experiments were conducted on a single NVIDIA A40 GPU; no non-standard hardware is required.

Create a virtual environment

pip install virtualenv
virtualenv -p /usr/bin/python3.8 venv
source venv/bin/activate

Installation guide

  • Install PyTorch:
pip3 install torch==1.10.2+cu113 torchvision==0.11.3+cu113 torchaudio==0.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
  • Install other dependencies:
pip install -r requirements.txt
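Before starting a long training run, it can help to confirm that the environment matches the stated requirements (Python 3.8 and a CUDA GPU with at least 40 GB of memory). The script below is a hypothetical helper, not part of the repository; it only reads `torch.__version__`, `torch.cuda.is_available()`, and `torch.cuda.get_device_properties()`, and it still runs when PyTorch is not installed.

```python
import importlib
import sys
from typing import Dict

# Hypothetical environment check; not part of the Co-Manifold repository.
MIN_GPU_MEM_GB = 40  # README requirement; measured here in decimal gigabytes


def check_environment() -> Dict[str, object]:
    """Report the Python version, PyTorch version, and per-GPU memory."""
    report: Dict[str, object] = {
        "python": f"{sys.version_info.major}.{sys.version_info.minor}",
        "python_matches": sys.version_info[:2] == (3, 8),
        "torch": None,
        "cuda_available": False,
        "gpus": [],
    }
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return report  # PyTorch not installed yet
    report["torch"] = torch.__version__
    report["cuda_available"] = torch.cuda.is_available()
    if report["cuda_available"]:
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mem_gb = props.total_memory / 1e9  # bytes -> decimal GB
            report["gpus"].append({
                "name": props.name,
                "memory_gb": round(mem_gb, 1),
                "meets_min": mem_gb >= MIN_GPU_MEM_GB,
            })
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```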

Dataset Preparation

The experiments are conducted on two publicly available datasets: the LA MRI dataset and the MSD BraTS MRI dataset.

Pre-processed data can be found in the data folder.

Trained Model Weights

Download the trained model weights from this shared drive link and place them under code_la/model_weights or code_brats/model_weights.

Train Model

  • To train the model on the LA MRI dataset with 10% labeled data:
cd code_la
CUDA_VISIBLE_DEVICES=0 nohup python train.py --dataset_name "LA" --labelnum 8 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.2 --t_m 0.1 --hidden-dim 256 --batch_size 4 --labeled_bs 2 &> la_10.out &
  • To train the model on the LA MRI dataset with 20% labeled data:
cd code_la
CUDA_VISIBLE_DEVICES=0 nohup python train.py --dataset_name "LA" --max_iteration 30000 --labelnum 16 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.1 --t_m 0.1 --hidden-dim 256 --batch_size 4 --labeled_bs 2 &> la_20.out &
  • To train the model on the MSD BraTS MRI dataset with 10% labeled data:
cd code_brats
nohup python train.py --dataset_name MSD_BRATS --labelnum 39 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.02 --t_m 0.2 --batch_size 6 --labeled_bs 3 &> msd_10_perc.out &
  • To train the model on the MSD BraTS MRI dataset with 20% labeled data:
cd code_brats
nohup python train.py --dataset_name MSD_BRATS --labelnum 77 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.02 --t_m 0.2 --batch_size 6 --labeled_bs 3 &> msd_20_perc.out &
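The `--batch_size` and `--labeled_bs` flags split each mini-batch into labeled and unlabeled scans (for LA, 2 labeled + 2 unlabeled). MC-Net-style codebases do this with a two-stream batch sampler; the sketch below is a simplified, hypothetical stand-in rather than this repository's implementation. It assumes labeled scans occupy indices `0 .. labelnum-1` and, for the percentage arithmetic, assumes the common 80-scan LA training split, under which `--labelnum 8` and `--labelnum 16` correspond to 10% and 20%.

```python
import random
from typing import Iterable, Iterator, List

# Simplified stand-in for MC-Net-style two-stream batching; not the
# Co-Manifold repository's implementation.

LA_TRAIN_SCANS = 80  # assumption: the common 80-scan LA training split


def labeled_fraction(labelnum: int, num_train: int) -> float:
    """Fraction of training scans that carry labels."""
    return labelnum / num_train


def two_stream_batches(labeled_idx: Iterable[int], unlabeled_idx: Iterable[int],
                       batch_size: int, labeled_bs: int,
                       seed: int = 0) -> Iterator[List[int]]:
    """Yield batches of `labeled_bs` labeled indices followed by
    `batch_size - labeled_bs` unlabeled indices."""
    if not 0 < labeled_bs < batch_size:
        raise ValueError("labeled_bs must satisfy 0 < labeled_bs < batch_size")
    rng = random.Random(seed)
    unlabeled_bs = batch_size - labeled_bs

    labeled = list(labeled_idx)
    rng.shuffle(labeled)
    unlabeled_pool = list(unlabeled_idx)
    if not unlabeled_pool:
        raise ValueError("unlabeled_idx must not be empty")

    def cycle_shuffled() -> Iterator[int]:
        # Re-shuffle and restart the unlabeled pool whenever it is exhausted.
        while True:
            rng.shuffle(unlabeled_pool)
            yield from unlabeled_pool

    unlabeled_stream = cycle_shuffled()
    # One pass over the labeled scans defines an epoch; an incomplete final
    # group of labeled scans is dropped.
    for start in range(0, len(labeled) - labeled_bs + 1, labeled_bs):
        lab = labeled[start:start + labeled_bs]
        unl = [next(unlabeled_stream) for _ in range(unlabeled_bs)]
        yield lab + unl
```

With the LA 10% settings, `list(two_stream_batches(range(8), range(8, 80), batch_size=4, labeled_bs=2))` yields four batches, each with two labeled and two unlabeled indices, and `labeled_fraction(8, LA_TRAIN_SCANS)` is 0.1. Cycling the larger unlabeled pool while passing once over the few labeled scans keeps every batch supervised.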

Test Model

  • To test the Co-Manifold ensemble model on the LA MRI dataset with 10% labeled data:
cd code_la
CUDA_VISIBLE_DEVICES=0 nohup python inference.py --dataset_name "LA" --labelnum 8 --dl_w 1.0 --ce_w 1.0 --alpha 0.005 --beta 0.2 --t_m 0.1 --hidden-dim 256 --batch_size 4 --labeled_bs 2 &> la_10_eval.out &
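The test command evaluates an ensemble of the two networks, but the README does not say how their outputs are combined. A common choice, assumed here purely for illustration, is to average the per-voxel softmax probabilities of the hyperbolic and Euclidean networks and take the arg-max class. The function `ensemble_predict` is hypothetical and operates on flattened voxel lists.

```python
from typing import List, Sequence

# Hypothetical helper for illustration; the README does not specify how
# inference.py fuses the two networks. Assumption: average the per-voxel
# class probabilities, then take the arg-max class.


def ensemble_predict(prob_a: Sequence[Sequence[float]],
                     prob_b: Sequence[Sequence[float]]) -> List[int]:
    """Fuse two networks' per-voxel class probabilities into hard labels."""
    if len(prob_a) != len(prob_b):
        raise ValueError("both probability maps must cover the same voxels")
    labels = []
    for pa, pb in zip(prob_a, prob_b):
        avg = [(x + y) / 2.0 for x, y in zip(pa, pb)]
        # Index of the highest averaged probability (ties go to the lower index).
        labels.append(max(range(len(avg)), key=avg.__getitem__))
    return labels
```

For example, if one network gives a voxel `[0.4, 0.6]` and the other gives `[0.8, 0.2]`, the average `[0.6, 0.4]` assigns class 0, overriding the first network's less confident vote.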

Acknowledgements

This repository makes liberal use of code from the capturing-implicit-hierarchical-structure and MC-Net repositories.

Citing Co-Manifold

If you find this repository useful, please consider giving us a star ⭐