-
Notifications
You must be signed in to change notification settings - Fork 4
/
model_restore.py
156 lines (134 loc) · 6.95 KB
/
model_restore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import nnunet
import torch
from batchgenerators.utilities.file_and_folder_operations import *
import importlib
import pkgutil
from nnUNetTrainer import nnUNetTrainer
def recursive_find_python_class(folder, trainer_name, current_module):
tr = None
for importer, modname, ispkg in pkgutil.iter_modules(folder):
# print(modname, ispkg)
if not ispkg:
m = importlib.import_module(current_module + "." + modname)
if hasattr(m, trainer_name):
tr = getattr(m, trainer_name)
break
if tr is None:
for importer, modname, ispkg in pkgutil.iter_modules(folder):
if ispkg:
next_current_module = current_module + "." + modname
tr = recursive_find_python_class([join(folder[0], modname)], trainer_name, current_module=next_current_module)
if tr is not None:
break
return tr
def restore_model(pkl_file, checkpoint=None, train=False, fp16=None):
"""
This is a utility function to load any nnUNet trainer from a pkl. It will recursively search
nnunet.trainig.network_training for the file that contains the trainer and instantiate it with the arguments saved in the pkl file. If checkpoint
is specified, it will furthermore load the checkpoint file in train/test mode (as specified by train).
The pkl file required here is the one that will be saved automatically when calling nnUNetTrainer.save_checkpoint.
:param pkl_file:
:param checkpoint:
:param train:
:param fp16: if None then we take no action. If True/False we overwrite what the model has in its init
:return:
"""
info = load_pickle(pkl_file)
init = info['init']
name = info['name']
# search_in = join(nnunet.__path__[0], "training", "network_training")
# tr = recursive_find_python_class([search_in], name, current_module="nnunet.training.network_training")
tr = nnUNetTrainer
if tr is None:
"""
Fabian only. This will trigger searching for trainer classes in other repositories as well
"""
try:
import meddec
search_in = join(meddec.__path__[0], "model_training")
tr = recursive_find_python_class([search_in], name, current_module="meddec.model_training")
except ImportError:
pass
if tr is None:
raise RuntimeError("Could not find the model trainer specified in checkpoint in nnunet.trainig.network_training. If it "
"is not located there, please move it or change the code of restore_model. Your model "
"trainer can be located in any directory within nnunet.trainig.network_training (search is recursive)."
"\nDebug info: \ncheckpoint file: %s\nName of trainer: %s " % (checkpoint, name))
assert issubclass(tr, nnUNetTrainer), "The network trainer was found but is not a subclass of nnUNetTrainer. " \
"Please make it so!"
# this is now deprecated
"""if len(init) == 7:
print("warning: this model seems to have been saved with a previous version of nnUNet. Attempting to load it "
"anyways. Expect the unexpected.")
print("manually editing init args...")
init = [init[i] for i in range(len(init)) if i != 2]"""
# ToDo Fabian make saves use kwargs, please...
trainer = tr(*init)
# We can hack fp16 overwriting into the trainer without changing the init arguments because nothing happens with
# fp16 in the init, it just saves it to a member variable
if fp16 is not None:
trainer.fp16 = fp16
trainer.process_plans(info['plans'])
if checkpoint is not None:
trainer.load_checkpoint(checkpoint, train)
return trainer
def load_best_model_for_inference(folder):
checkpoint = join(folder, "model_best.model")
pkl_file = checkpoint + ".pkl"
return restore_model(pkl_file, checkpoint, False)
def load_model_and_checkpoint_files(folder, folds=None, mixed_precision=None, checkpoint_name="model_best"):
"""
used for if you need to ensemble the five models of a cross-validation. This will restore the model from the
checkpoint in fold 0, load all parameters of the five folds in ram and return both. This will allow for fast
switching between parameters (as opposed to loading them form disk each time).
This is best used for inference and test prediction
:param folder:
:param folds:
:param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init
:return:
"""
if isinstance(folds, str):
folds = [join(folder, "all")]
assert isdir(folds[0]), "no output folder for fold %s found" % folds
elif isinstance(folds, (list, tuple)):
if len(folds) == 1 and folds[0] == "all":
folds = [join(folder, "all")]
else:
folds = [join(folder, "fold_%d" % i) for i in folds]
assert all([isdir(i) for i in folds]), "list of folds specified but not all output folders are present"
elif isinstance(folds, int):
folds = [join(folder, "fold_%d" % folds)]
assert all([isdir(i) for i in folds]), "output folder missing for fold %d" % folds
elif folds is None:
print("folds is None so we will automatically look for output folders (not using \'all\'!)")
folds = subfolders(folder, prefix="fold")
print("found the following folds: ", folds)
else:
raise ValueError("Unknown value for folds. Type: %s. Expected: list of int, int, str or None", str(type(folds)))
trainer = restore_model(join(folds[0], "%s.model.pkl" % checkpoint_name), fp16=mixed_precision)
trainer.output_folder = folder
trainer.output_folder_base = folder
trainer.update_fold(0)
trainer.initialize(False)
all_best_model_files = [join(i, "%s.model" % checkpoint_name) for i in folds]
print("using the following model files: ", all_best_model_files)
all_params = [torch.load(i, map_location=torch.device('cpu')) for i in all_best_model_files]
return trainer, all_params
if __name__ == "__main__":
pkl = "/home/fabian/PhD/results/nnUNetV2/nnUNetV2_3D_fullres/Task004_Hippocampus/fold0/model_best.model.pkl"
checkpoint = pkl[:-4]
train = False
trainer = restore_model(pkl, checkpoint, train)