Skip to content

Commit

Permalink
increase the learning rate of prostate
Browse files Browse the repository at this point in the history
fixing a bug with infinite sampler
  • Loading branch information
jizong committed Apr 14, 2021
1 parent def5bdc commit c44b4dc
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion semi_seg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
acdc_ratios = [0.01, 0.015, 0.025, 1.0]
prostate_ratio = [0.08, 0.1, 0.13, 0.18, 1.0] # 3 5 7
prostate_ratio = [0.08, 0.13, 0.18, 1.0] # 3 5 7

ratio_zoom = {"acdc": acdc_ratios,
"prostate": prostate_ratio}
2 changes: 1 addition & 1 deletion semi_seg/dsutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_dataloaders(config, group_val_patient=True):
f"given {labeled_data_ratio} and {unlabeled_data_ratio}")
data_manager = datainterface(root_dir=DATA_PATH, labeled_data_ratio=labeled_data_ratio,
unlabeled_data_ratio=unlabeled_data_ratio, verbose=False,
seed=0 if dataset_name == "acdc" else 1)# avoid bad random seed for prostate
seed=0 if dataset_name == "acdc" else 10)# avoid bad random seed for prostate

label_set, unlabel_set, val_set = data_manager._create_semi_supervised_datasets( # noqa
labeled_transform=augmentinferface.pretrain,
Expand Down
2 changes: 1 addition & 1 deletion semi_seg/scripts/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"mmwhs": 5,
}
ft_lr_zooms = {"acdc": 0.0000001,
"prostate": 0.00000025,
"prostate": 0.0000005,
"spleen": 0.000001,
"mmwhs": 0.000001}

Expand Down
10 changes: 7 additions & 3 deletions semi_seg/trainers/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import List, Dict, Any, Callable

from deepclustering2.dataloader.sampler import InfiniteRandomSampler
from deepclustering2.dataset import PatientSampler
from deepclustering2.meters2 import StorageIncomeDict, Storage, EpochResultDict
from deepclustering2.tqdm import item2str
Expand Down Expand Up @@ -29,18 +30,21 @@ def _get_contrastive_dataloader(partial_loader, config):
dataset_name = config["Data"]["name"]
batch_sampler = None
batch_size = contrastive_config["group_sample_num"] * {"acdc": 3, "prostate": 7}[dataset_name]
sampler = InfiniteRandomSampler(dataset, shuffle=True)

if dataset_name == "acdc":
# only group the acdc dataset
batch_sampler = ContrastBatchSampler(
dataset=dataset,
**contrastive_config
)
) # this batch sampler is without end
batch_size = 1
sampler = None

contrastive_loader = DataLoader(
dataset, batch_sampler=batch_sampler,
dataset, batch_sampler=batch_sampler, sampler=sampler,
num_workers=num_workers, batch_size=batch_size,
pin_memory=True, shuffle=False if batch_sampler else True,
pin_memory=True, shuffle=False,
)

from contrastyou.augment import ACDCStrongTransforms
Expand Down
4 changes: 2 additions & 2 deletions semi_seg/transfer.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ local_folder=./runs
# "${local_folder}/0402_semi_acdc/"

rsync -azP --exclude "*/*.png" --exclude "*/tra/*/*.pth" \
--exclude "*/patient*" --exclude "*/*calculquebec.ca" \
beluga:/lustre04/scratch/jizong/Contrast-You/semi_seg/runs/0415_prostate \
--exclude "*/patient*" \
[email protected]:/root/main/runs/0416_prostate \
"${local_folder}"

#rsync -azP --exclude "*/*.png" --exclude "*/*.pth" \
Expand Down

0 comments on commit c44b4dc

Please sign in to comment.