-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
49 lines (42 loc) · 1.43 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import itertools
import os

import torch
import yaml

from options.test_options import CycleTestOptions
from models.cyclegan import CycleGAN
from data.dataset import CycleDataset
from utils.model_utils import save_outs
if __name__ == '__main__':
# parse options
parser = CycleTestOptions()
opt = parser.parse()
opt.phase = 'test'
parser.export_options(opt)
result_dir = os.path.join(opt.result_dir, f"{opt.model_name}_{opt.phase}")
os.makedirs(result_dir, exist_ok=True)
# config params
with open(opt.config, 'r') as file:
config = yaml.safe_load(file)
config['dataset']['scale_size'] = config['dataset']['crop_size']
# create model + dataset
model = CycleGAN(opt, config)
model.general_setup()
dataset = CycleDataset(
opt.to_train,
dataroot=opt.dataroot,
**config['dataset']
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=config['dataset']['num_workers']
)
print(f"Saving images in {result_dir}")
for i, data in enumerate(dataloader):
if i > opt.num_tests:
break
model.setup_input(data)
out = model.test() # default is real image and style transferred image
# save in results dir
file_name = os.path.splitext(os.path.basename(model.image_paths[0]))[0]
save_outs(out, os.path.join(result_dir, file_name), opt.save_separate)