inpaint_from_model.py (forked from Sohl-Dickstein/Diffusion-Probabilistic-Models)
#!/usr/bin/env python
"""
Use a trained model to inpaint a dataset with missing data.

Dataset details:
- The dataset should be saved in the Caffe image data format: image files plus an
  index file listing the path of each image.
- Missing-data mask information must also be supplied, in the same Caffe image
  data format. Mask images should be boolean: true on missing pixels, false on
  observed pixels.
"""
import argparse
import os
import sys
from os.path import join

import numpy as np
import PIL.Image
import progressbar
from scipy.misc import toimage
from theano.misc import pkl_utils

from sampler import inpaint_masked_samples
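
# Note: scipy.misc.toimage was removed in SciPy 1.2, so on newer environments a
# small PIL-based stand-in (a sketch matching the cmin/cmax usage below) could
# replace the import above:
#
#   def toimage(x, cmin=0.0, cmax=1.0):
#       x = np.clip((x - cmin) / (cmax - cmin), 0.0, 1.0)
#       return PIL.Image.fromarray(np.uint8(x * 255.0))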


def save_single_image(x, save_dir, save_name):
    """Save a single image array (values in [0, 1]) to save_dir/save_name."""
    toimage(x, cmin=0.0, cmax=1.0).save(join(save_dir, save_name))


def load_image(path, scale=255.0):
    """Load an image from disk and rescale its pixel values to [0, 1]."""
    return np.float32(PIL.Image.open(path)) / scale


def load_images(images_dir, image_size, index_file, flatten=True):
    """Load every image listed in index_file (first column) under images_dir.

    Returns an array of shape (N, prod(image_size)) when flatten is True,
    otherwise (N, 1) + image_size.
    """
    if not os.path.exists(images_dir):
        raise IOError('Error: %s does not exist!' % images_dir)
    raw_im_data = np.loadtxt(os.path.join(images_dir, index_file),
                             delimiter=' ', dtype=str)
    total_images = raw_im_data.shape[0]
    if flatten:
        ims = np.zeros((total_images, np.prod(image_size)))
    else:
        ims = np.zeros((total_images, 1) + image_size)
    for idx in np.arange(total_images):
        sys.stdout.write('loading image %d of %d \r' % (idx + 1, total_images))
        sys.stdout.flush()
        im = load_image(os.path.join(images_dir, raw_im_data[idx][0]))
        if flatten:
            ims[idx, :] = im.reshape(np.prod(image_size))
        else:
            ims[idx, 0, :, :] = im
    sys.stdout.write('\n')
    return ims


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=100, type=int,
                        help='Batch size')
    parser.add_argument('--resume_file', type=str, required=True,
                        help='Path of the saved model to load')
    parser.add_argument('--missing_dataset_path', default=None, type=str,
                        help='Path to dir containing index.txt and index_mask.txt')
    return parser.parse_args()
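

# Example invocation (file and directory names here are hypothetical):
#
#   python inpaint_from_model.py --resume_file saved_model.pkl \
#       --missing_dataset_path /data/mnist_corrupted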


if __name__ == '__main__':
    args = parse_args()
    db_dir = args.missing_dataset_path
    new_db_path = db_dir + '_dpm_ip'
    if not os.path.exists(new_db_path):
        os.mkdir(new_db_path)
    batch_size = args.batch_size
    mnist_size = (28, 28)

    load = pkl_utils.load
    print "Loading trained model from " + args.resume_file
    with open(args.resume_file, "rb") as f:
        main_loop = load(f)
    # The plotting extension (assumed to sit at position 7 of the saved
    # MainLoop) exposes both the trained model and the get_mu_sigma function
    # that the sampler needs.
    plot_samples_ext = main_loop.extensions[7]
    get_mu_sigma = plot_samples_ext.get_mu_sigma
    dpm = plot_samples_ext.model
    input_h = dpm.spatial_width
    input_w = dpm.spatial_width

    raw_im_data = np.loadtxt(os.path.join(db_dir, 'index.txt'),
                             delimiter=' ', dtype=str)
    print "Loading images to inpaint..."
    X = load_images(db_dir, (input_h, input_w), index_file='index.txt',
                    flatten=False)
    print "Loading masks..."
    X_mask = load_images(db_dir, (input_h, input_w), index_file='index_mask.txt',
                         flatten=False).astype(bool)
    # Calculated over the whole dataset as:
    #   scl = 1. / np.sqrt(np.mean((X - np.mean(X)) ** 2))
    #   shft = -np.mean(X * scl)
    # Same method as in the original code, except computed on the whole dataset
    # rather than per minibatch.
    scl = 3.24154476773
    shft = -0.42452
    X_scale_shift = X * scl + shft
    N = raw_im_data.shape[0]
    # Python 2 integer division: images beyond the last full batch are skipped.
    n_batches = N / batch_size
    pbar = progressbar.ProgressBar(
        widgets=[progressbar.FormatLabel('\rProcessed %(value)d of %(max)d Batches '),
                 progressbar.Bar()],
        maxval=n_batches, term_width=50).start()
    with open(join(new_db_path, 'index.txt'), 'w') as db_file:
        for b in np.arange(n_batches):
            X_batch = X_scale_shift[b * batch_size:(b + 1) * batch_size, :, :, :]
            # Dataset masks are True on missing pixels (see the docstring), but
            # the sampler is handed the inverse: True marks observed pixels.
            X_batch_mask = np.logical_not(
                X_mask[b * batch_size:(b + 1) * batch_size, :, :, :])
            X0 = inpaint_masked_samples(dpm, get_mu_sigma, X_batch,
                                        X_batch_mask.ravel())
            for idx in np.arange(batch_size):
                abs_idx = b * batch_size + idx
                save_name = raw_im_data[abs_idx][0].replace('corrupted', 'ip')
                save_single_image(X0[idx, 0, :, :], new_db_path, save_name)
                db_file.write('%s %s\n' % (save_name, raw_im_data[abs_idx][1]))
            pbar.update(b + 1)
    pbar.finish()
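    # At this point new_db_path holds one inpainted image per processed input
    # (file names with 'corrupted' replaced by 'ip') plus an index.txt that
    # pairs each saved image with the label column from the source index.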