-
Notifications
You must be signed in to change notification settings - Fork 4
/
default_configuration.py
83 lines (69 loc) · 3.86 KB
/
default_configuration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nnUNetTrainerV2 import nnUNetTrainerV2
from nnUNetTrainer import nnUNetTrainer
import nnunet
from batchgenerators.utilities.file_and_folder_operations import *
# from nnunet.experiment_planning.summarize_plans import summarize_plans
from model_restore import recursive_find_python_class
from configuration import network_training_output_dir, preprocessing_output_dir, default_plans_identifier
def get_configuration_from_output_folder(folder):
    """Recover the run settings encoded in a training output folder path.

    After the leading ``network_training_output_dir`` prefix is removed, the
    remainder of ``folder`` must consist of exactly three "/"-separated
    components: ``<configuration>/<task>/<trainer>__<plans_identifier>``.
    Leading and trailing slashes around that remainder are ignored, so a path
    with a trailing "/" is accepted.

    Args:
        folder: path string that begins with ``network_training_output_dir``.
            The prefix is removed purely by length; it is not verified that
            ``folder`` actually starts with it.

    Returns:
        Tuple ``(configuration, task, trainer, plans_identifier)`` of strings.

    Raises:
        ValueError: if the remainder does not have exactly three components,
            or the last component is not two parts joined by ``"__"``.
    """
    # split off network_training_output_dir
    relative = folder[len(network_training_output_dir):].strip("/")

    parts = relative.split("/")
    if len(parts) != 3:
        raise ValueError(
            "expected '<configuration>/<task>/<trainer>__<plans_identifier>' below the "
            "network training output dir, got %r" % relative)
    configuration, task, trainer_and_plans_identifier = parts

    trainer_and_plans = trainer_and_plans_identifier.split("__")
    if len(trainer_and_plans) != 2:
        raise ValueError(
            "expected '<trainer>__<plans_identifier>', got %r" % trainer_and_plans_identifier)
    trainer, plans_identifier = trainer_and_plans

    return configuration, task, trainer, plans_identifier
def get_default_configuration(network, task, network_trainer, plans_identifier=default_plans_identifier,
                              search_in=("/home/SENSETIME/luoxiangde.vendor/Projects/nnUNet_mini", "training", "network_training"),
                              base_module='nnunet.training.network_training'):
    """Resolve the plans file, folders and basic training settings for a run.

    Args:
        network: one of ``'2d'``, ``'3d_lowres'``, ``'3d_fullres'`` or
            ``'3d_cascade_fullres'``.
        task: task name; used as a sub-folder of ``preprocessing_output_dir``
            and as a component of the output folder.
        network_trainer: trainer name. It is only used to build the output
            folder name (and is printed); the returned trainer class is always
            ``nnUNetTrainer``, regardless of this value.
        plans_identifier: prefix of the plans file name
            (``<plans_identifier>_plans_2D.pkl`` / ``_plans_3D.pkl``).
        search_in: path components; only joined and printed here, no class
            lookup is performed with them.
        base_module: only printed here.

    Returns:
        Tuple ``(plans_file, output_folder_name, dataset_directory,
        batch_dice, stage, trainer_class)``.

    Raises:
        AssertionError: if ``network`` is not one of the supported values.
        RuntimeError: if ``'3d_lowres'`` or ``'3d_cascade_fullres'`` is
            requested but the plans contain only a single stage.
    """
    assert network in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'], \
        "network can only be one of the following: '2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'"

    dataset_directory = join(preprocessing_output_dir, task)

    # 2D and 3D configurations read separate plans files.
    if network == '2d':
        plans_file = join(preprocessing_output_dir, task,
                          plans_identifier + "_plans_2D.pkl")
    else:
        plans_file = join(preprocessing_output_dir, task,
                          plans_identifier + "_plans_3D.pkl")

    # NOTE(review): load_pickle presumably unpickles the file (helper comes from
    # the star import above) - only point this at plans files you trust.
    plans = load_pickle(plans_file)
    possible_stages = list(plans['plans_per_stage'].keys())

    if (network == '3d_cascade_fullres' or network == "3d_lowres") and len(possible_stages) == 1:
        raise RuntimeError("3d_lowres/3d_cascade_fullres only applies if there is more than one stage. This task does "
                           "not require the cascade. Run 3d_fullres instead")

    # 2d and 3d_lowres use stage 0; the other configurations use the last
    # stage key listed in the plans.
    if network == '2d' or network == "3d_lowres":
        stage = 0
    else:
        stage = possible_stages[-1]

    print([join(*search_in)], network_trainer, base_module)
    # The trainer class is fixed; network_trainer is not used to look one up.
    trainer_class = nnUNetTrainer

    output_folder_name = join(network_training_output_dir,
                              network, task, network_trainer + "__" + plans_identifier)

    print("###############################################")
    print("I am running the following nnUNet: %s" % network)
    print("My trainer class is: ", trainer_class)
    print("For that I will be using the following configuration:")
    # summarize_plans(plans_file)
    print("I am using stage %d from these plans" % stage)

    # Batch dice is used for 2d and for multi-stage plans, except for 3d_lowres.
    if (network == '2d' or len(possible_stages) > 1) and network != '3d_lowres':
        batch_dice = True
        print("I am using batch dice + CE loss")
    else:
        batch_dice = False
        print("I am using sample dice + CE loss")

    print("\nI am using data from this folder: ", join(
        dataset_directory, plans['data_identifier']))
    print("###############################################")
    return plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class