#!/usr/bin/env python3
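"""Evaluate a pretrained DORN model on NYU Depth v2 using median matching.

Usage sketch (assuming Sacred's standard ``with key=value`` command-line
interface; the config keys are the ones defined in cfg() below):

    python evaluate_hints_median_matching.py
    python evaluate_hints_median_matching.py with dataset_type=test small_run=10
"""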
import os
import torch
import json
from torch.utils.data import DataLoader
from utils.train_utils import init_randomness, worker_init_randomness
from utils.eval_utils import evaluate_model_on_dataset, evaluate_model_on_data_entry
from models.core.checkpoint import load_checkpoint, safe_makedir
from models import make_model
from sacred import Experiment
from sacred.observers import FileStorageObserver
from time import perf_counter
# Dataset
from models.data.nyuv2_official_nohints_sid_dataset import nyuv2_nohints_sid_ingredient, load_data
ex = Experiment('eval_median_matching', ingredients=[nyuv2_nohints_sid_ingredient])
# Tensorboardx
# writer = SummaryWriter()
@ex.config
def cfg(data_config):
    model_config = {  # Load pretrained model for testing
        "model_name": "DORN_median_matching",
        "model_params": {
            "in_channels": 3,
            "in_height": 257,
            "in_width": 353,
            "sid_bins": data_config["sid_bins"],
            "offset": data_config["offset"],
            "min_depth": data_config["min_depth"],
            "max_depth": data_config["max_depth"],
            "alpha": data_config["alpha"],
            "beta": data_config["beta"],
            "frozen": True,
            "pretrained": True,
            "state_dict_file": os.path.join("models", "torch_params_nyuv2_BGR.pth.tar"),
        },
        "model_state_dict_fn": None,
    }
    ckpt_file = None  # Median matching eval
    dataset_type = "val"
    save_outputs = True
    seed = 95290421
    small_run = 0
    entry = None
    # print(data_config.keys())
    output_dir = os.path.join("results",
                              data_config["data_name"],                  # e.g. nyu_depth_v2
                              "{}_{}".format(dataset_type, small_run),
                              model_config["model_name"])                # e.g. DORN_nyu_nohints
    safe_makedir(output_dir)
    ex.observers.append(FileStorageObserver.create(os.path.join(output_dir, "runs")))

    cuda_device = "0"  # The GPU index to run on. Should be a string.
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
    # print("after: {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using device: {} (CUDA_VISIBLE_DEVICES = {})".format(device,
                                                                os.environ["CUDA_VISIBLE_DEVICES"]))
    if ckpt_file is not None:
        model_update, _, _ = load_checkpoint(ckpt_file)
        model_config.update(model_update)
        del model_update, _  # So sacred doesn't collect them.
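
# For reference, a minimal sketch of the median-matching step the model name
# refers to: rescale the prediction so its median depth matches the ground
# truth's, a common protocol for evaluating monocular depth predictions that
# are only correct up to scale. This helper is illustrative (an assumption
# about what DORN_median_matching does internally); the pipeline below does
# not call it.
def median_match(pred, gt):
    """Return pred rescaled so that median(pred) == median(gt)."""
    scale = torch.median(gt) / torch.median(pred)
    return pred * scale
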
@ex.automain
def main(model_config,
         dataset_type,
         save_outputs,
         output_dir,
         data_config,
         seed,
         small_run,
         entry,
         device):
    # Load the model
    model = make_model(**model_config)
    model.to(device)
    model.eval()
    # model.sid_obj.to(device)
    print(model)

    # Load the data
    _, val, test = load_data()
    dataset = test if dataset_type == "test" else val
    init_randomness(seed)

    if entry is None:
        print("Evaluating the model on {} ({})".format(data_config["data_name"],
                                                       dataset_type))
        evaluate_model_on_dataset(model, dataset, small_run, device, save_outputs, output_dir)
    else:
        print("Evaluating {}".format(entry))
        evaluate_model_on_data_entry(model, dataset, entry, device)
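
# The experiment can also be launched programmatically (standard Sacred API;
# the config keys are the ones defined in cfg() above), e.g.:
#
#     from evaluate_hints_median_matching import ex
#     ex.run(config_updates={"dataset_type": "test", "small_run": 10})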