From c44b4dc7b936f7777a6d0683af1428fca7120e62 Mon Sep 17 00:00:00 2001 From: jizong Date: Wed, 14 Apr 2021 19:20:44 +0800 Subject: [PATCH] increase the learning rate of prostate fixing a bug with infinite sampler --- semi_seg/__init__.py | 2 +- semi_seg/dsutils.py | 2 +- semi_seg/scripts/helper.py | 2 +- semi_seg/trainers/_helper.py | 10 +++++++--- semi_seg/transfer.sh | 4 ++-- 5 files changed, 12 insertions(+), 8 deletions(-) diff --git a/semi_seg/__init__.py b/semi_seg/__init__.py index 90690139..4cefec66 100644 --- a/semi_seg/__init__.py +++ b/semi_seg/__init__.py @@ -1,5 +1,5 @@ acdc_ratios = [0.01, 0.015, 0.025, 1.0] -prostate_ratio = [0.08, 0.1, 0.13, 0.18, 1.0] # 3 5 7 +prostate_ratio = [0.08, 0.13, 0.18, 1.0] # 3 5 7 ratio_zoom = {"acdc": acdc_ratios, "prostate": prostate_ratio} diff --git a/semi_seg/dsutils.py b/semi_seg/dsutils.py index 4148260c..36fd7c9a 100644 --- a/semi_seg/dsutils.py +++ b/semi_seg/dsutils.py @@ -38,7 +38,7 @@ def get_dataloaders(config, group_val_patient=True): f"given {labeled_data_ratio} and {unlabeled_data_ratio}") data_manager = datainterface(root_dir=DATA_PATH, labeled_data_ratio=labeled_data_ratio, unlabeled_data_ratio=unlabeled_data_ratio, verbose=False, - seed=0 if dataset_name == "acdc" else 1)# avoid bad random seed for prostate + seed=0 if dataset_name == "acdc" else 10)# avoid bad random seed for prostate label_set, unlabel_set, val_set = data_manager._create_semi_supervised_datasets( # noqa labeled_transform=augmentinferface.pretrain, diff --git a/semi_seg/scripts/helper.py b/semi_seg/scripts/helper.py index 3f0d5316..29f4d70d 100644 --- a/semi_seg/scripts/helper.py +++ b/semi_seg/scripts/helper.py @@ -12,7 +12,7 @@ "mmwhs": 5, } ft_lr_zooms = {"acdc": 0.0000001, - "prostate": 0.00000025, + "prostate": 0.0000005, "spleen": 0.000001, "mmwhs": 0.000001} diff --git a/semi_seg/trainers/_helper.py b/semi_seg/trainers/_helper.py index 43603ae3..68434daf 100644 --- a/semi_seg/trainers/_helper.py +++ b/semi_seg/trainers/_helper.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import List, Dict, Any, Callable +from deepclustering2.dataloader.sampler import InfiniteRandomSampler from deepclustering2.dataset import PatientSampler from deepclustering2.meters2 import StorageIncomeDict, Storage, EpochResultDict from deepclustering2.tqdm import item2str @@ -29,18 +30,21 @@ def _get_contrastive_dataloader(partial_loader, config): dataset_name = config["Data"]["name"] batch_sampler = None batch_size = contrastive_config["group_sample_num"] * {"acdc": 3, "prostate": 7}[dataset_name] + sampler = InfiniteRandomSampler(dataset, shuffle=True) + if dataset_name == "acdc": # only group the acdc dataset batch_sampler = ContrastBatchSampler( dataset=dataset, **contrastive_config - ) + ) # this batch sampler is without end batch_size = 1 + sampler = None contrastive_loader = DataLoader( - dataset, batch_sampler=batch_sampler, + dataset, batch_sampler=batch_sampler, sampler=sampler, num_workers=num_workers, batch_size=batch_size, - pin_memory=True, shuffle=False if batch_sampler else True, + pin_memory=True, shuffle=False, ) from contrastyou.augment import ACDCStrongTransforms diff --git a/semi_seg/transfer.sh b/semi_seg/transfer.sh index 60769cfd..3d3dc210 100644 --- a/semi_seg/transfer.sh +++ b/semi_seg/transfer.sh @@ -10,8 +10,8 @@ local_folder=./runs # "${local_folder}/0402_semi_acdc/" rsync -azP --exclude "*/*.png" --exclude "*/tra/*/*.pth" \ - --exclude "*/patient*" --exclude "*/*calculquebec.ca" \ - beluga:/lustre04/scratch/jizong/Contrast-You/semi_seg/runs/0415_prostate \ + --exclude "*/patient*" \ + root@jizong.buzz:/root/main/runs/0416_prostate \ "${local_folder}" #rsync -azP --exclude "*/*.png" --exclude "*/*.pth" \