inference.py
import argparse
import os

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision
import yaml
from PIL import Image

import utils
from models import DenoisingDiffusion
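
# Single-image inference: load a config and a pre-trained diffusion checkpoint,
# condition on a degraded input image, run implicit (DDIM-style) sampling from
# Gaussian noise, and save the restored result under results/.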


def parse_args_and_config():
    parser = argparse.ArgumentParser(description='Restoring Weather with Patch-Based Denoising Diffusion Models')
    parser.add_argument("--config", type=str, required=True,
                        help="Name of the config file (looked up under the configs/ directory)")
    parser.add_argument('--resume', default='', type=str,
                        help='Path to the diffusion model checkpoint to load for evaluation')
    parser.add_argument("--sampling_timesteps", type=int, default=25,
                        help="Number of implicit sampling steps")
    parser.add_argument("--eta", type=float, default=0,
                        help="DDIM eta; 0 gives deterministic sampling, larger values add stochasticity")
    parser.add_argument('--seed', default=1234, type=int, metavar='N',
                        help='Seed for initializing random number generators (default: 1234)')
    parser.add_argument("--condition_image", required=True, type=str,
                        help="Path to the conditional (degraded) input image")
    args = parser.parse_args()

    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    return args, new_config


def dict2namespace(config):
    """Recursively convert a (possibly nested) dict into an argparse.Namespace."""
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace
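
# For example, dict2namespace({"data": {"image_size": 64}}) yields a namespace
# ns with ns.data.image_size == 64, matching how the YAML config is accessed
# below (e.g. config.data.image_size).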


def inverse_data_transform(X):
    # Map model output from the [-1, 1] range back to [0, 1] for saving as an image.
    return torch.clamp((X + 1.0) / 2.0, 0.0, 1.0)


def main():
    args, config = parse_args_and_config()
    to_tensor = torchvision.transforms.ToTensor()

    # Set up the device to run on
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("Using device: {}".format(device))
    config.device = device
    if torch.cuda.is_available():
        print('Note: evaluation (restoration) is currently supported on a single GPU only!')

    # Set random seeds for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    cudnn.benchmark = True

    # Build the diffusion model and load the (EMA) checkpoint for evaluation
    diffusion = DenoisingDiffusion(args, config)
    diffusion.load_ddm_ckpt(args.resume, ema=True)
    diffusion.model.eval()

    os.makedirs("results", exist_ok=True)
    with torch.no_grad():
        # Load the conditional image, force RGB, and resize it to the model's input size
        x_cond = Image.open(args.condition_image).convert("RGB")
        x_cond = x_cond.resize((config.data.image_size, config.data.image_size), Image.BICUBIC)
        x_cond = to_tensor(x_cond).to(diffusion.device)
        utils.logging.save_image(x_cond, "results/input.png")
        x_cond = x_cond[None, :, :, :]

        # Sample from Gaussian noise, conditioned on the degraded input image
        x = torch.randn(x_cond.size(), device=diffusion.device)
        x_output = diffusion.sample_image(x_cond, x, eta=args.eta, patch_locs=None, patch_size=None)
        x_output = inverse_data_transform(x_output)
        utils.logging.save_image(x_output, "results/output.png")


if __name__ == "__main__":
    main()
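
# Example invocation (illustrative; the config name, checkpoint path, and image
# path below are placeholders, not files guaranteed to ship with this repository):
#
#   python inference.py --config allweather.yml \
#       --resume ckpts/diffusion_model.pth.tar \
#       --condition_image demo/rainy.png \
#       --sampling_timesteps 25 --eta 0
#
# The resized conditional input is saved to results/input.png and the restored
# image to results/output.png.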