diff --git a/semi_seg/dsutils.py b/semi_seg/dsutils.py index 1c614328..4148260c 100644 --- a/semi_seg/dsutils.py +++ b/semi_seg/dsutils.py @@ -1,14 +1,15 @@ from copy import deepcopy -from contrastyou import DATA_PATH -from contrastyou.datasets import ACDCSemiInterface, SpleenSemiInterface, ProstateSemiInterface, MMWHSSemiInterface from deepclustering2.dataloader.distributed import InfiniteDistributedSampler from deepclustering2.dataloader.sampler import InfiniteRandomSampler from deepclustering2.dataset import PatientSampler from loguru import logger -from semi_seg.augment import ACDCStrongTransforms, SpleenStrongTransforms, ProstateStrongTransforms from torch.utils.data import DataLoader +from contrastyou import DATA_PATH +from contrastyou.datasets import ACDCSemiInterface, SpleenSemiInterface, ProstateSemiInterface, MMWHSSemiInterface +from semi_seg.augment import ACDCStrongTransforms, SpleenStrongTransforms, ProstateStrongTransforms + dataset_zoos = { "acdc": ACDCSemiInterface, "spleen": SpleenSemiInterface, @@ -36,7 +37,8 @@ def get_dataloaders(config, group_val_patient=True): RuntimeError(f"labeled and unlabeled data should be set properly, " f"given {labeled_data_ratio} and {unlabeled_data_ratio}") data_manager = datainterface(root_dir=DATA_PATH, labeled_data_ratio=labeled_data_ratio, - unlabeled_data_ratio=unlabeled_data_ratio, verbose=False) + unlabeled_data_ratio=unlabeled_data_ratio, verbose=False, + seed=0 if dataset_name == "acdc" else 1)# avoid bad random seed for prostate label_set, unlabel_set, val_set = data_manager._create_semi_supervised_datasets( # noqa labeled_transform=augmentinferface.pretrain, diff --git a/semi_seg/scripts/run_semi b/semi_seg/scripts/run_semi index 4bfda5e6..8271c422 100755 --- a/semi_seg/scripts/run_semi +++ b/semi_seg/scripts/run_semi @@ -104,11 +104,11 @@ function run_mt_infonce() { # repeat python run_infonce_semi.py ${comm_cmd} --save_dir ${save_dir}/${prefix} -b ${num_batches} -e ${max_epoch} -s ${rand_seed} \ - --time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.1 --mt_weight 0.1 --config_path=${config_path} + --time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.1 --mt_weight 0.5 --config_path=${config_path} python run_infonce_semi.py ${comm_cmd} --save_dir ${save_dir}/${prefix} -b ${num_batches} -e ${max_epoch} -s ${rand_seed} \ - --time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.2 --mt_weight 0.1 --config_path=${config_path} + --time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.2 --mt_weight 0.5 --config_path=${config_path} python run_infonce_semi.py ${comm_cmd} --save_dir ${save_dir}/${prefix} -b ${num_batches} -e ${max_epoch} -s ${rand_seed} \ - --time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.5 --mt_weight 0.1 --config_path=${config_path} + --time=4 --arch_checkpoint=$checkpoint meanteacherinfonce --info_weight 0.5 --mt_weight 0.5 --config_path=${config_path} } diff --git a/semi_seg/trainers/_helper.py b/semi_seg/trainers/_helper.py index 084c4526..43603ae3 100644 --- a/semi_seg/trainers/_helper.py +++ b/semi_seg/trainers/_helper.py @@ -2,9 +2,6 @@ from pathlib import Path from typing import List, Dict, Any, Callable -from contrastyou.arch.unet import enable_grad, enable_bn_tracking -from contrastyou.datasets._seg_datset import ContrastBatchSampler # noqa -from contrastyou.helper import get_dataset from deepclustering2.dataset import PatientSampler from deepclustering2.meters2 import StorageIncomeDict, Storage, EpochResultDict from deepclustering2.tqdm import item2str @@ -13,6 +10,10 @@ from torch import nn from torch.utils.data.dataloader import _BaseDataLoaderIter as BaseDataLoaderIter, DataLoader # noqa +from contrastyou.arch.unet import enable_grad, enable_bn_tracking +from contrastyou.datasets._seg_datset import ContrastBatchSampler # noqa +from contrastyou.helper import get_dataset + def _get_contrastive_dataloader(partial_loader, config): # going to get all dataset with contrastive sampler @@ -25,14 +26,21 @@ def _get_contrastive_dataloader(partial_loader, config): contrastive_config = config["ContrastiveLoaderParams"] num_workers = contrastive_config.pop("num_workers") - batch_sampler = ContrastBatchSampler( - dataset=dataset, - **contrastive_config - ) + dataset_name = config["Data"]["name"] + batch_sampler = None + batch_size = contrastive_config["group_sample_num"] * {"acdc": 3, "prostate": 7}[dataset_name] + if dataset_name == "acdc": + # only group the acdc dataset + batch_sampler = ContrastBatchSampler( + dataset=dataset, + **contrastive_config + ) + batch_size = 1 + contrastive_loader = DataLoader( dataset, batch_sampler=batch_sampler, - num_workers=num_workers, - pin_memory=True + num_workers=num_workers, batch_size=batch_size, + pin_memory=True, shuffle=False if batch_sampler else True, ) from contrastyou.augment import ACDCStrongTransforms diff --git a/semi_seg/transfer.sh b/semi_seg/transfer.sh index a5c8418d..60769cfd 100644 --- a/semi_seg/transfer.sh +++ b/semi_seg/transfer.sh @@ -9,9 +9,9 @@ local_folder=./runs # root@jizong.buzz:/root/main/runs/0402_semi/ \ # "${local_folder}/0402_semi_acdc/" -rsync -azP --exclude "*/*.png" --exclude "*/*.pth" \ +rsync -azP --exclude "*/*.png" --exclude "*/tra/*/*.pth" \ --exclude "*/patient*" --exclude "*/*calculquebec.ca" \ - root@jizong.buzz:/root/main/runs/0415_prostate \ + beluga:/lustre04/scratch/jizong/Contrast-You/semi_seg/runs/0415_prostate \ "${local_folder}" #rsync -azP --exclude "*/*.png" --exclude "*/*.pth" \