# run_img2pc.py
# Forked from matheusgadelha/MRTNet.
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import argparse
import os
from tools.Trainer import ImageToPCTrainer
from tools.PointCloudDataset import ImageDataset
from models.AutoEncoder import PointCloudVAE
from models.AutoEncoder import ChamferLoss
from models.AutoEncoder import ChamferWithNormalLoss
from models.AutoEncoder import L2WithNormalLoss
from models.ImageToShape import MultiResImageToShape
from models.ImageToShape import SingleResImageToShape
from models.ImageToShape import FCImageToShape
def _build_parser():
    """Assemble the command-line parser used by this script.

    Four string-valued options (experiment name, encoder architecture,
    pretrained switch, category code) plus a boolean ``--train`` flag that
    defaults to False.
    """
    cli = argparse.ArgumentParser(
        description='MultiResolution image to shape model.')

    # (option strings, help text, default value) for each string option,
    # listed in the order they should appear in --help.
    string_options = (
        (("-n", "--name"), "Name of the experiment.", "MRI2PC"),
        (("-a", "--arch"), "Encoder architecture.", "vgg"),
        (("-pt", "--pretrained"), "Use pretrained net", "True"),
        (("-c", "--category"), "Category code (all is possible)", "all"),
    )
    for flags, help_text, default_value in string_options:
        cli.add_argument(*flags, type=str, help=help_text,
                         default=default_value)

    cli.add_argument("--train", dest='train', action='store_true')
    cli.set_defaults(train=False)
    return cli


parser = _build_parser()
# Path to the images the model is run on; the main block passes it to
# ImageDataset. Change this to a path with the images you want to test.
image_datapath = "notseen_real"
if __name__ == '__main__':
    # Script entry point: build the image-to-shape network, load its
    # checkpoint, and hand it to ImageToPCTrainer together with a loader over
    # the images found at `image_datapath`.
    args = parser.parse_args()
    # NOTE(review): args.train is parsed but never read in this block.

    # --pretrained arrives as a string. Only the exact spellings "True" and
    # "False" are recognised; any other value leaves the flag as None.
    ptrain = {"True": True, "False": False}.get(args.pretrained)

    # The composed name is given to the model and also names the log folder.
    full_name = "{}_{}_{}_{}".format(args.name, args.category, args.arch, ptrain)
    #full_name = args.name
    # Function-call form prints identically on Python 2 and is valid on
    # Python 3 (the old `print full_name` statement is a SyntaxError there).
    print(full_name)

    # Alternative architectures; swap one in for the multi-resolution model
    # below to run it instead.
    #mri2pc = FCImageToShape(size=4096, dim=3, batch_size=1,
    #        name=full_name, pretrained=ptrain, arch=args.arch)
    #mri2pc = SingleResImageToShape(size=4096, dim=3, batch_size=1,
    #        name=full_name, pretrained=ptrain, arch=args.arch)
    mri2pc = MultiResImageToShape(size=4096, dim=3, batch_size=1,
                                  name=full_name, pretrained=ptrain,
                                  arch=args.arch)

    # presumably restores previously saved weights from 'checkpoint' --
    # confirm against the model's load() implementation.
    mri2pc.load('checkpoint')

    optimizer = optim.Adam(mri2pc.parameters(), lr=0.001)

    test_dataset = ImageDataset(image_datapath)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=True, num_workers=2)

    # Make sure the per-experiment log directory exists before the trainer
    # is given its path.
    log_dir = os.path.join("log", full_name)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # The second positional argument is None -- presumably the training
    # loader slot, leaving only the test loader; verify against
    # ImageToPCTrainer's signature.
    trainer = ImageToPCTrainer(mri2pc, None, test_loader,
                               optimizer, ChamferLoss(), log_dir=log_dir)
    trainer.run()