Skip to content

Commit

Permalink
remove contrastive sampler for prostate dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
jizong committed Apr 14, 2021
1 parent 9075023 commit def5bdc
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 18 deletions.
10 changes: 6 additions & 4 deletions semi_seg/dsutils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from copy import deepcopy

from contrastyou import DATA_PATH
from contrastyou.datasets import ACDCSemiInterface, SpleenSemiInterface, ProstateSemiInterface, MMWHSSemiInterface
from deepclustering2.dataloader.distributed import InfiniteDistributedSampler
from deepclustering2.dataloader.sampler import InfiniteRandomSampler
from deepclustering2.dataset import PatientSampler
from loguru import logger
from semi_seg.augment import ACDCStrongTransforms, SpleenStrongTransforms, ProstateStrongTransforms
from torch.utils.data import DataLoader

from contrastyou import DATA_PATH
from contrastyou.datasets import ACDCSemiInterface, SpleenSemiInterface, ProstateSemiInterface, MMWHSSemiInterface
from semi_seg.augment import ACDCStrongTransforms, SpleenStrongTransforms, ProstateStrongTransforms

dataset_zoos = {
"acdc": ACDCSemiInterface,
"spleen": SpleenSemiInterface,
Expand Down Expand Up @@ -36,7 +37,8 @@ def get_dataloaders(config, group_val_patient=True):
RuntimeError(f"labeled and unlabeled data should be set properly, "
f"given {labeled_data_ratio} and {unlabeled_data_ratio}")
data_manager = datainterface(root_dir=DATA_PATH, labeled_data_ratio=labeled_data_ratio,
unlabeled_data_ratio=unlabeled_data_ratio, verbose=False)
unlabeled_data_ratio=unlabeled_data_ratio, verbose=False,
seed=0 if dataset_name == "acdc" else 1)# avoid bad random seed for prostate

label_set, unlabel_set, val_set = data_manager._create_semi_supervised_datasets( # noqa
labeled_transform=augmentinferface.pretrain,
Expand Down
6 changes: 3 additions & 3 deletions semi_seg/scripts/run_semi
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ function run_mt_infonce() {

# repeat
python run_infonce_semi.py ${comm_cmd} --save_dir ${save_dir}/${prefix} -b ${num_batches} -e ${max_epoch} -s ${rand_seed} \
--time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.1 --mt_weight 0.1 --config_path=${config_path}
--time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.1 --mt_weight 0.5 --config_path=${config_path}

python run_infonce_semi.py ${comm_cmd} --save_dir ${save_dir}/${prefix} -b ${num_batches} -e ${max_epoch} -s ${rand_seed} \
--time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.2 --mt_weight 0.1 --config_path=${config_path}
--time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.2 --mt_weight 0.5 --config_path=${config_path}

python run_infonce_semi.py ${comm_cmd} --save_dir ${save_dir}/${prefix} -b ${num_batches} -e ${max_epoch} -s ${rand_seed} \
--time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.5 --mt_weight 0.1 --config_path=${config_path}
--time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.5 --mt_weight 0.5 --config_path=${config_path}
}
26 changes: 17 additions & 9 deletions semi_seg/trainers/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
from pathlib import Path
from typing import List, Dict, Any, Callable

from contrastyou.arch.unet import enable_grad, enable_bn_tracking
from contrastyou.datasets._seg_datset import ContrastBatchSampler # noqa
from contrastyou.helper import get_dataset
from deepclustering2.dataset import PatientSampler
from deepclustering2.meters2 import StorageIncomeDict, Storage, EpochResultDict
from deepclustering2.tqdm import item2str
Expand All @@ -13,6 +10,10 @@
from torch import nn
from torch.utils.data.dataloader import _BaseDataLoaderIter as BaseDataLoaderIter, DataLoader # noqa

from contrastyou.arch.unet import enable_grad, enable_bn_tracking
from contrastyou.datasets._seg_datset import ContrastBatchSampler # noqa
from contrastyou.helper import get_dataset


def _get_contrastive_dataloader(partial_loader, config):
# going to get all dataset with contrastive sampler
Expand All @@ -25,14 +26,21 @@ def _get_contrastive_dataloader(partial_loader, config):

contrastive_config = config["ContrastiveLoaderParams"]
num_workers = contrastive_config.pop("num_workers")
batch_sampler = ContrastBatchSampler(
dataset=dataset,
**contrastive_config
)
dataset_name = config["Data"]["name"]
batch_sampler = None
batch_size = contrastive_config["group_sample_num"] * {"acdc": 3, "prostate": 7}[dataset_name]
if dataset_name == "acdc":
# only group the acdc dataset
batch_sampler = ContrastBatchSampler(
dataset=dataset,
**contrastive_config
)
batch_size = 1

contrastive_loader = DataLoader(
dataset, batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True
num_workers=num_workers, batch_size=batch_size,
pin_memory=True, shuffle=False if batch_sampler else True,
)

from contrastyou.augment import ACDCStrongTransforms
Expand Down
4 changes: 2 additions & 2 deletions semi_seg/transfer.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ local_folder=./runs
# [email protected]:/root/main/runs/0402_semi/ \
# "${local_folder}/0402_semi_acdc/"

rsync -azP --exclude "*/*.png" --exclude "*/*.pth" \
rsync -azP --exclude "*/*.png" --exclude "*/tra/*/*.pth" \
--exclude "*/patient*" --exclude "*/*calculquebec.ca" \
[email protected]:/root/main/runs/0415_prostate \
beluga:/lustre04/scratch/jizong/Contrast-You/semi_seg/runs/0415_prostate \
"${local_folder}"

#rsync -azP --exclude "*/*.png" --exclude "*/*.pth" \
Expand Down

0 comments on commit def5bdc

Please sign in to comment.