This is the PyTorch implementation of our MICCAI 2024 paper "Robust Semi-Supervised Multimodal Medical Image Segmentation via Cross Modality Collaboration" by Xiaogen Zhou, Yiyou Sun, Min Deng, Winnie Chiu Wing Chu and Qi Dou*.
* denotes the corresponding author.
Multimodal learning leverages complementary information derived from different modalities, thereby enhancing performance in medical image segmentation. However, prevailing multimodal learning methods heavily rely on extensive well-annotated data from various modalities to achieve accurate segmentation performance. This dependence often poses a challenge in clinical settings due to the limited availability of such data. Moreover, the inherent anatomical misalignment between different imaging modalities further complicates the endeavor to enhance segmentation performance. To address this problem, we propose a novel semi-supervised multimodal segmentation framework that is robust to scarce labeled data and misaligned modalities. Our framework employs a novel cross modality collaboration strategy to distill modality-independent knowledge, which is inherently associated with each modality, and integrates this information into a unified fusion layer for feature amalgamation. With a channel-wise semantic consistency loss, our framework ensures alignment of modality-independent information from a feature-wise perspective across modalities, thereby fortifying it against misalignments in multimodal scenarios. Furthermore, our framework effectively integrates contrastive consistent learning to regulate anatomical structures, facilitating anatomical-wise prediction alignment on unlabeled data in semi-supervised segmentation tasks. Our method achieves competitive performance compared to other multimodal methods across three tasks: cardiac, abdominal multi-organ, and thyroid-associated orbitopathy segmentation. It also demonstrates outstanding robustness in scenarios involving scarce labeled data and misaligned modalities.
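To make the channel-wise semantic consistency idea concrete, the sketch below shows one minimal way such a loss could be written in PyTorch. The tensor shapes, the channel-wise mean pooling, and the cosine-similarity form are our own assumptions for illustration, not the paper's exact formulation.

```python
# Minimal sketch of a channel-wise consistency loss between two modality branches.
# Shapes, pooling choice, and cosine form are assumptions, not the paper's exact loss.
import torch
import torch.nn.functional as F

def channel_consistency_loss(feat_a: torch.Tensor, feat_b: torch.Tensor) -> torch.Tensor:
    """feat_a, feat_b: (B, C, D, H, W) features from two modality-specific encoders."""
    # Collapse the spatial dimensions so each sample is described by one value per channel.
    desc_a = feat_a.flatten(2).mean(dim=2)   # (B, C)
    desc_b = feat_b.flatten(2).mean(dim=2)   # (B, C)
    # Encourage the channel-wise descriptors of the two modalities to agree.
    return (1.0 - F.cosine_similarity(desc_a, desc_b, dim=1)).mean()

# Example with random tensors standing in for encoder outputs.
a = torch.randn(2, 32, 8, 16, 16)
b = torch.randn(2, 32, 8, 16, 16)
print(channel_consistency_loss(a, b))
```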
- Download from GitHub
```bash
git clone https://github.com/med-air/CMC.git
cd CMC
```
- Create conda environment
```bash
conda create --name CMC python=3.8.18
conda activate CMC
pip install -r requirements.txt
# CUDA 11.8
conda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
```
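Optionally, a quick sanity check (not part of the original instructions) can confirm that the installed PyTorch build sees CUDA:

```python
# Optional environment check; verifies the PyTorch/CUDA install from the commands above.
import torch

print(torch.__version__)            # expected: 2.3.0
print(torch.cuda.is_available())    # True if the CUDA 11.8 build found a GPU
print(torch.version.cuda)           # expected: 11.8
```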
Note: You can download our datasets as follows. Please download our pre-processed AMOS dataset from here and put it into the folder 'dataset' (a quick layout check is sketched after the list below):
1. MS-CMRSeg 2019 dataset: here
2. AMOS Dataset: here
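The exact sub-folder names depend on how the downloaded archives are unpacked; the following is only a minimal sketch of a layout check, and the 'AMOS' and 'MS-CMRSeg2019' folder names are assumptions rather than names prescribed by this repository.

```python
# Minimal layout check; sub-folder names are assumptions, adjust them to your unpacked archives.
from pathlib import Path

dataset_root = Path("dataset")
for name in ["AMOS", "MS-CMRSeg2019"]:   # hypothetical sub-folder names
    sub = dataset_root / name
    print(f"{sub}: {'found' if sub.is_dir() else 'missing'}")
```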
Our encoder and decoder use a foundation model's [link] pre-trained weights [link] and the pre-trained weights [link] of SAM-Med3D [link]. You can also download them from here. Please download these weights and put them into the folder 'pretrain_model' before running the following scripts.
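For reference, loading a checkpoint from 'pretrain_model' follows the usual PyTorch pattern sketched below; the file name 'sam_med3d.pth' and the key layout are assumptions for illustration only, and the training script is expected to handle this loading itself.

```python
# Generic PyTorch checkpoint-loading pattern; the file name and key layout are assumptions.
import torch

state = torch.load("pretrain_model/sam_med3d.pth", map_location="cpu")
state_dict = state.get("state_dict", state)  # some checkpoints nest weights under 'state_dict'
# model.load_state_dict(state_dict, strict=False)  # strict=False tolerates head/key mismatches
print(f"loaded {len(state_dict)} tensors")
```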
#### Training stage
```bash
python main.py --backbone 'Foundation_model' --batch_size 4 --img_size 96
```
#### Testing stage
```bash
python test.py --backbone 'Foundation_model'
```
We also provide our model checkpoints for the experiments on the AMOS dataset as listed below (Mean Dice is the evaluation metric).
| Training | CT (Mean Dice (%)) | MRI (Mean Dice (%)) | Checkpoint |
|---|---|---|---|
| 10% Labeled data | 76.28 | 84.27 | [checkpoint] |
| 20% Labeled data | 84.57 | 89.05 | [checkpoint] |
Note: Please download these checkpoints and put them into the folder 'checkpoint', then run the following script for testing to reproduce our experimental results.
```bash
python test.py --backbone 'Foundation_model'
```
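For reference, Mean Dice reported in the table above is the standard overlap metric; the snippet below is a generic per-class Dice computation under our own assumptions (integer label maps, background class 0 excluded), not the repository's evaluation code.

```python
# Generic per-class Dice for integer label maps; class count and layout are assumptions.
import torch

def mean_dice(pred: torch.Tensor, target: torch.Tensor, num_classes: int, eps: float = 1e-5) -> float:
    """pred, target: integer label volumes of identical shape; background class 0 is skipped."""
    scores = []
    for c in range(1, num_classes):
        p = (pred == c).float()
        t = (target == c).float()
        inter = (p * t).sum()
        denom = p.sum() + t.sum()
        if denom > 0:
            scores.append((2 * inter + eps) / (denom + eps))
    return torch.stack(scores).mean().item() if scores else 0.0

# Toy example: identical volumes give a Dice of 1.0.
vol = torch.randint(0, 3, (64, 64, 64))
print(mean_dice(vol, vol, num_classes=3))
```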
If this repository is useful for your research, please cite:
@inproceedings{2024cmc,
  title={Robust Semi-Supervised Multimodal Medical Image Segmentation via Cross Modality Collaboration},
  author={Xiaogen Zhou and Yiyou Sun and Min Deng and Winnie Chiu Wing Chu and Qi Dou},
  booktitle={International Conference on Medical Image Computing and Computer Assisted Intervention},
  year={2024}
}
If you have any questions, please feel free to open an issue here or contact '[email protected]'.