-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcompletion.py
56 lines (48 loc) · 1.46 KB
/
completion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#completion.py
import torch
import random
import torchvision
from model import Model
from dataloader import Dataloader
from checkpoints import Checkpoints
from evaluation import Evaluate
from generation import Generator
import os
import datetime
import utils
import copy
import config
# Parse the command-line arguments and seed the Python and torch RNGs with
# the same value so runs are reproducible.
args = config.parse_args()
random.seed(args.manual_seed)
torch.manual_seed(args.manual_seed)
# Output directories are derived from the result path, then the full
# argument set is persisted via utils.saveargs.
args.save = os.path.join(args.result_path, 'save')
args.logs = os.path.join(args.result_path, 'logs')
utils.saveargs(args)

# Initialize the checkpoint helper (used both by Model.setup and for the
# optional weight loading below).
checkpoints = Checkpoints(args)

# Create the models. setup() returns the networks (indexed D, G, E below)
# and a criterion, which this script does not use.
models = Model(args)
gogan_model, criterion = models.setup(checkpoints)
netD = gogan_model[0]
netG = gogan_model[1]
netE = gogan_model[2]


def _load_weights(net, path, ckpt):
    """Load the checkpoint at *path* into *net*, if a path was given.

    Args:
        net: network whose ``load_state_dict`` receives the loaded checkpoint.
        path: checkpoint location passed to ``ckpt.load``; an empty string
            (or None) means "no checkpoint" and leaves *net* unchanged.
        ckpt: the Checkpoints instance used to read the file.
    """
    # Truthiness instead of ``path is not ''``: identity comparison with a
    # string literal is implementation-dependent (SyntaxWarning on 3.8+),
    # and this also skips a None path instead of trying to load it.
    if path:
        net.load_state_dict(ckpt.load(path))


# Optionally restore each network from the command-line-specified files,
# in the same D -> G -> E order as before.
_load_weights(netD, args.netD, checkpoints)
_load_weights(netG, args.netG, checkpoints)
_load_weights(netE, args.netE, checkpoints)

# Data loading: the "Test" split, in a fixed (unshuffled) order.
dataloader = Dataloader(args)
test_loader = dataloader.create("Test", shuffle=False)

# Wrap the three networks in a Generator and run its interpolate routine
# over the test loader.
# Alternative entry points (swap in as needed):
#   evaluate = Evaluate(args, netD, netG, netE)
#   test_loss = evaluate.complete(test_loader)
#   loss = generator.generate_one(test_loader)
generator = Generator(args, netD, netG, netE)
loss = generator.interpolate(test_loader)